From 5ff223f6559472d18bf018cd3690cee128201ea6 Mon Sep 17 00:00:00 2001 From: cxykevin Date: Sun, 25 Aug 2024 20:05:24 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8trie=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=95=88=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 1 + server/email.py | 3 +++ server/main.py | 56 +++++++++++++++++++++++++++++++----------------- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/requirements.txt b/requirements.txt index 63e1ed5..22f7fb7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ MarkupSafe==2.1.5 multidict==6.0.5 pydantic==2.8.2 pydantic_core==2.20.1 +pygtrie==2.5.0 PyMySQL==1.1.1 python-multipart==0.0.9 sniffio==1.3.1 diff --git a/server/email.py b/server/email.py index fc65eaf..7d04635 100644 --- a/server/email.py +++ b/server/email.py @@ -15,6 +15,9 @@ async def sendemail(to_addr, links): print(links) html = "验证链接:["+links+"]" + if (SMTP_SRV == ''): + return 0 + msg = MIMEText(html, 'html', 'utf-8') msg['From'] = ADDR msg['To'] = to_addr diff --git a/server/main.py b/server/main.py index bd64d40..ccb2a90 100644 --- a/server/main.py +++ b/server/main.py @@ -14,7 +14,7 @@ import re import asyncio import uvicorn from .cfg import config - +import pygtrie ALGORITHM = config["common"]["algorithm"] ACCESS_TOKEN_EXPIRE_MINUTES = config["common"]["access_token_expire_minutes"] @@ -25,23 +25,31 @@ MANAGE_KEY = config["common"]["manage_key"] REDIRECT_URL_WHITELIST = config["common"]["redirect_url_whitelist"] -tokens = {} -apikeys = {} -emails = {} +tokens = pygtrie.StringTrie() +apikeys = pygtrie.StringTrie() +emails = pygtrie.StringTrie() email_send_lst = [] +def prep_uuid(uuid: str): + return '/'.join(list(uuid)) + + +def clean_uuid(uuid: str): + return uuid.replace("/", "") + + async def clean_sys(): while 1: await asyncio.sleep(CLEAN_TIMEOUT) - for k in list(tokens.keys()): - if (tokens[k][1] < datetime.now()): + for k, v in tokens.items(): + if (v[1] < datetime.now()): del tokens[k] - for k in list(apikeys.keys()): - if (apikeys[k][1] < datetime.now()): + for k, v in apikeys.items(): + if (v[1] < datetime.now()): del apikeys[k] - for k in list(emails.keys()): - if (emails[k][2] < datetime.now()): + for k, v in emails.items(): + if (v[2] < datetime.now()): del emails[k] @@ -89,7 +97,7 @@ async def authenticate_user(username: str, password: str): async def create_token(username: str): - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) tokens[tkn] = (username, datetime.now() + timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES))) return tkn @@ -119,9 +127,9 @@ async def check_apikey(tkn: str): async def login_callback(response: Response, username: str = Form(), password: str = Form()): if (await authenticate_user(username, password)): tokennow = await create_token(username) - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) apikeys[tkn] = tokens[tokennow] - response.set_cookie("session", tokennow) + response.set_cookie("session", clean_uuid(tokennow)) return {"msg": "", "key": tkn} else: return {"msg": "用户名或密码错误", "key": ""} @@ -138,12 +146,13 @@ async def login_callback(username: str = Form(), password: str = Form(), email: if (not re.fullmatch(regex, email)): return {"msg": "邮箱不合法", "code": 1} if not (await db.check_user(username)): - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) emails[tkn] = (username, hashlib.sha256( password.encode("utf-8")).hexdigest(), datetime.now() + timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email) - email_send_lst.append((email, ROOT+"/api/checkemail?uid="+tkn)) + email_send_lst.append( + (email, ROOT+"/api/checkemail?uid="+clean_uuid(tkn))) return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0} else: return {"msg": "用户名重复", "code": 1} @@ -151,6 +160,7 @@ async def login_callback(username: str = Form(), password: str = Form(), email: @app.get("/api/checkemail") async def checkemail(uid: str): + uid = prep_uuid(uid) if (uid not in emails): return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"}) if (emails[uid][2] < datetime.now()): @@ -168,6 +178,7 @@ async def checkemail(uid: str): @app.get("/api/resetpasswd", response_class=HTMLResponse) async def resetpasswd(uid: str, response: Response): + uid = prep_uuid(uid) if (uid not in emails): return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的验证id"}) if (emails[uid][2] < datetime.now()): @@ -176,9 +187,9 @@ async def resetpasswd(uid: str, response: Response): if (emails[uid][1] != ""): return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"}) tokennow = await create_token(emails[uid][0]) - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) apikeys[tkn] = tokens[tokennow] - response.set_cookie("session", tokennow) + response.set_cookie("session", clean_uuid(tokennow)) del emails[uid] return '正在跳转正在跳转' @@ -187,11 +198,12 @@ async def resetpasswd(uid: str, response: Response): async def resetpasswd(username: str = Form()): if (await db.check_user(username)): email = await db.get_email(username) - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) emails[tkn] = (username, "", datetime.now() + timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email) - email_send_lst.append((email, ROOT+"/api/resetpasswd?uid="+tkn)) + email_send_lst.append( + (email, ROOT+"/api/resetpasswd?uid="+clean_uuid(tkn))) return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0} else: return {"msg": "用户名不存在", "code": 1} @@ -199,6 +211,7 @@ async def resetpasswd(username: str = Form()): @app.get("/api/getinfo") async def get_user_info(uid: str): + uid = prep_uuid(uid) username = await check_apikey(uid) if (username == ""): return {"code": 1, "msg": "token无效", "data": {}} @@ -210,6 +223,7 @@ async def changepasswd(password: str = Form(), session: Annotated[str | None, Co if (check_passwd(password)): return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"} if (session is not None): + session = prep_uuid(session) username = await check_token(session) if (username != ""): await db.update_passwd(username, hashlib.sha256( @@ -229,9 +243,10 @@ async def login(state: str = "", client_id: str = "", redirect_url: str = "/user if (now_redirect_url not in REDIRECT_URL_WHITELIST): return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "无效的重定向URL"}) if (session is not None): + session = prep_uuid(session) username = await check_token(session) if (username != ""): - tkn = uuid.uuid4().hex + tkn = prep_uuid(uuid.uuid4().hex) apikeys[tkn] = tokens[session] return RedirectResponse(url=redirect_url+f"#access_token={tkn}&token_type=Bearer&state={state}") return templates.TemplateResponse("login.html", {"request": {}, "redirect_url": redirect_url, "state": state}) @@ -246,6 +261,7 @@ async def login(redirect_url: str): async def login(session: Annotated[str | None, Cookie()] = None): if (session is not None): + session = prep_uuid(session) username = await check_token(session) if (username == ""): return RedirectResponse(url="/login?redirect_url=/user")