使用trie优化效率
This commit is contained in:
parent
4a417107af
commit
5ff223f655
|
@ -16,6 +16,7 @@ MarkupSafe==2.1.5
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
pydantic==2.8.2
|
pydantic==2.8.2
|
||||||
pydantic_core==2.20.1
|
pydantic_core==2.20.1
|
||||||
|
pygtrie==2.5.0
|
||||||
PyMySQL==1.1.1
|
PyMySQL==1.1.1
|
||||||
python-multipart==0.0.9
|
python-multipart==0.0.9
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
|
|
|
@ -15,6 +15,9 @@ async def sendemail(to_addr, links):
|
||||||
print(links)
|
print(links)
|
||||||
html = "验证链接:[<a href=\""+links+"\">"+links+"</a>]"
|
html = "验证链接:[<a href=\""+links+"\">"+links+"</a>]"
|
||||||
|
|
||||||
|
if (SMTP_SRV == ''):
|
||||||
|
return 0
|
||||||
|
|
||||||
msg = MIMEText(html, 'html', 'utf-8')
|
msg = MIMEText(html, 'html', 'utf-8')
|
||||||
msg['From'] = ADDR
|
msg['From'] = ADDR
|
||||||
msg['To'] = to_addr
|
msg['To'] = to_addr
|
||||||
|
|
|
@ -14,7 +14,7 @@ import re
|
||||||
import asyncio
|
import asyncio
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from .cfg import config
|
from .cfg import config
|
||||||
|
import pygtrie
|
||||||
|
|
||||||
ALGORITHM = config["common"]["algorithm"]
|
ALGORITHM = config["common"]["algorithm"]
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = config["common"]["access_token_expire_minutes"]
|
ACCESS_TOKEN_EXPIRE_MINUTES = config["common"]["access_token_expire_minutes"]
|
||||||
|
@ -25,23 +25,31 @@ MANAGE_KEY = config["common"]["manage_key"]
|
||||||
REDIRECT_URL_WHITELIST = config["common"]["redirect_url_whitelist"]
|
REDIRECT_URL_WHITELIST = config["common"]["redirect_url_whitelist"]
|
||||||
|
|
||||||
|
|
||||||
tokens = {}
|
tokens = pygtrie.StringTrie()
|
||||||
apikeys = {}
|
apikeys = pygtrie.StringTrie()
|
||||||
emails = {}
|
emails = pygtrie.StringTrie()
|
||||||
email_send_lst = []
|
email_send_lst = []
|
||||||
|
|
||||||
|
|
||||||
|
def prep_uuid(uuid: str):
|
||||||
|
return '/'.join(list(uuid))
|
||||||
|
|
||||||
|
|
||||||
|
def clean_uuid(uuid: str):
|
||||||
|
return uuid.replace("/", "")
|
||||||
|
|
||||||
|
|
||||||
async def clean_sys():
|
async def clean_sys():
|
||||||
while 1:
|
while 1:
|
||||||
await asyncio.sleep(CLEAN_TIMEOUT)
|
await asyncio.sleep(CLEAN_TIMEOUT)
|
||||||
for k in list(tokens.keys()):
|
for k, v in tokens.items():
|
||||||
if (tokens[k][1] < datetime.now()):
|
if (v[1] < datetime.now()):
|
||||||
del tokens[k]
|
del tokens[k]
|
||||||
for k in list(apikeys.keys()):
|
for k, v in apikeys.items():
|
||||||
if (apikeys[k][1] < datetime.now()):
|
if (v[1] < datetime.now()):
|
||||||
del apikeys[k]
|
del apikeys[k]
|
||||||
for k in list(emails.keys()):
|
for k, v in emails.items():
|
||||||
if (emails[k][2] < datetime.now()):
|
if (v[2] < datetime.now()):
|
||||||
del emails[k]
|
del emails[k]
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,7 +97,7 @@ async def authenticate_user(username: str, password: str):
|
||||||
|
|
||||||
|
|
||||||
async def create_token(username: str):
|
async def create_token(username: str):
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
tokens[tkn] = (username, datetime.now() +
|
tokens[tkn] = (username, datetime.now() +
|
||||||
timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES)))
|
timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES)))
|
||||||
return tkn
|
return tkn
|
||||||
|
@ -119,9 +127,9 @@ async def check_apikey(tkn: str):
|
||||||
async def login_callback(response: Response, username: str = Form(), password: str = Form()):
|
async def login_callback(response: Response, username: str = Form(), password: str = Form()):
|
||||||
if (await authenticate_user(username, password)):
|
if (await authenticate_user(username, password)):
|
||||||
tokennow = await create_token(username)
|
tokennow = await create_token(username)
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
apikeys[tkn] = tokens[tokennow]
|
apikeys[tkn] = tokens[tokennow]
|
||||||
response.set_cookie("session", tokennow)
|
response.set_cookie("session", clean_uuid(tokennow))
|
||||||
return {"msg": "", "key": tkn}
|
return {"msg": "", "key": tkn}
|
||||||
else:
|
else:
|
||||||
return {"msg": "用户名或密码错误", "key": ""}
|
return {"msg": "用户名或密码错误", "key": ""}
|
||||||
|
@ -138,12 +146,13 @@ async def login_callback(username: str = Form(), password: str = Form(), email:
|
||||||
if (not re.fullmatch(regex, email)):
|
if (not re.fullmatch(regex, email)):
|
||||||
return {"msg": "邮箱不合法", "code": 1}
|
return {"msg": "邮箱不合法", "code": 1}
|
||||||
if not (await db.check_user(username)):
|
if not (await db.check_user(username)):
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
emails[tkn] = (username, hashlib.sha256(
|
emails[tkn] = (username, hashlib.sha256(
|
||||||
password.encode("utf-8")).hexdigest(), datetime.now() +
|
password.encode("utf-8")).hexdigest(), datetime.now() +
|
||||||
timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email)
|
timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email)
|
||||||
|
|
||||||
email_send_lst.append((email, ROOT+"/api/checkemail?uid="+tkn))
|
email_send_lst.append(
|
||||||
|
(email, ROOT+"/api/checkemail?uid="+clean_uuid(tkn)))
|
||||||
return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0}
|
return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0}
|
||||||
else:
|
else:
|
||||||
return {"msg": "用户名重复", "code": 1}
|
return {"msg": "用户名重复", "code": 1}
|
||||||
|
@ -151,6 +160,7 @@ async def login_callback(username: str = Form(), password: str = Form(), email:
|
||||||
|
|
||||||
@app.get("/api/checkemail")
|
@app.get("/api/checkemail")
|
||||||
async def checkemail(uid: str):
|
async def checkemail(uid: str):
|
||||||
|
uid = prep_uuid(uid)
|
||||||
if (uid not in emails):
|
if (uid not in emails):
|
||||||
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
|
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
|
||||||
if (emails[uid][2] < datetime.now()):
|
if (emails[uid][2] < datetime.now()):
|
||||||
|
@ -168,6 +178,7 @@ async def checkemail(uid: str):
|
||||||
|
|
||||||
@app.get("/api/resetpasswd", response_class=HTMLResponse)
|
@app.get("/api/resetpasswd", response_class=HTMLResponse)
|
||||||
async def resetpasswd(uid: str, response: Response):
|
async def resetpasswd(uid: str, response: Response):
|
||||||
|
uid = prep_uuid(uid)
|
||||||
if (uid not in emails):
|
if (uid not in emails):
|
||||||
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的验证id"})
|
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的验证id"})
|
||||||
if (emails[uid][2] < datetime.now()):
|
if (emails[uid][2] < datetime.now()):
|
||||||
|
@ -176,9 +187,9 @@ async def resetpasswd(uid: str, response: Response):
|
||||||
if (emails[uid][1] != ""):
|
if (emails[uid][1] != ""):
|
||||||
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
|
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
|
||||||
tokennow = await create_token(emails[uid][0])
|
tokennow = await create_token(emails[uid][0])
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
apikeys[tkn] = tokens[tokennow]
|
apikeys[tkn] = tokens[tokennow]
|
||||||
response.set_cookie("session", tokennow)
|
response.set_cookie("session", clean_uuid(tokennow))
|
||||||
del emails[uid]
|
del emails[uid]
|
||||||
return '<html><head><meta http-equiv="refresh" content="0;url=/user"><title>正在跳转</title></head><body>正在跳转</body></html>'
|
return '<html><head><meta http-equiv="refresh" content="0;url=/user"><title>正在跳转</title></head><body>正在跳转</body></html>'
|
||||||
|
|
||||||
|
@ -187,11 +198,12 @@ async def resetpasswd(uid: str, response: Response):
|
||||||
async def resetpasswd(username: str = Form()):
|
async def resetpasswd(username: str = Form()):
|
||||||
if (await db.check_user(username)):
|
if (await db.check_user(username)):
|
||||||
email = await db.get_email(username)
|
email = await db.get_email(username)
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
emails[tkn] = (username, "", datetime.now() +
|
emails[tkn] = (username, "", datetime.now() +
|
||||||
timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email)
|
timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES)), email)
|
||||||
|
|
||||||
email_send_lst.append((email, ROOT+"/api/resetpasswd?uid="+tkn))
|
email_send_lst.append(
|
||||||
|
(email, ROOT+"/api/resetpasswd?uid="+clean_uuid(tkn)))
|
||||||
return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0}
|
return {"msg": "验证邮件已发送到邮箱,请在10分钟内完成验证", "code": 0}
|
||||||
else:
|
else:
|
||||||
return {"msg": "用户名不存在", "code": 1}
|
return {"msg": "用户名不存在", "code": 1}
|
||||||
|
@ -199,6 +211,7 @@ async def resetpasswd(username: str = Form()):
|
||||||
|
|
||||||
@app.get("/api/getinfo")
|
@app.get("/api/getinfo")
|
||||||
async def get_user_info(uid: str):
|
async def get_user_info(uid: str):
|
||||||
|
uid = prep_uuid(uid)
|
||||||
username = await check_apikey(uid)
|
username = await check_apikey(uid)
|
||||||
if (username == ""):
|
if (username == ""):
|
||||||
return {"code": 1, "msg": "token无效", "data": {}}
|
return {"code": 1, "msg": "token无效", "data": {}}
|
||||||
|
@ -210,6 +223,7 @@ async def changepasswd(password: str = Form(), session: Annotated[str | None, Co
|
||||||
if (check_passwd(password)):
|
if (check_passwd(password)):
|
||||||
return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"}
|
return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"}
|
||||||
if (session is not None):
|
if (session is not None):
|
||||||
|
session = prep_uuid(session)
|
||||||
username = await check_token(session)
|
username = await check_token(session)
|
||||||
if (username != ""):
|
if (username != ""):
|
||||||
await db.update_passwd(username, hashlib.sha256(
|
await db.update_passwd(username, hashlib.sha256(
|
||||||
|
@ -229,9 +243,10 @@ async def login(state: str = "", client_id: str = "", redirect_url: str = "/user
|
||||||
if (now_redirect_url not in REDIRECT_URL_WHITELIST):
|
if (now_redirect_url not in REDIRECT_URL_WHITELIST):
|
||||||
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "无效的重定向URL"})
|
return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "无效的重定向URL"})
|
||||||
if (session is not None):
|
if (session is not None):
|
||||||
|
session = prep_uuid(session)
|
||||||
username = await check_token(session)
|
username = await check_token(session)
|
||||||
if (username != ""):
|
if (username != ""):
|
||||||
tkn = uuid.uuid4().hex
|
tkn = prep_uuid(uuid.uuid4().hex)
|
||||||
apikeys[tkn] = tokens[session]
|
apikeys[tkn] = tokens[session]
|
||||||
return RedirectResponse(url=redirect_url+f"#access_token={tkn}&token_type=Bearer&state={state}")
|
return RedirectResponse(url=redirect_url+f"#access_token={tkn}&token_type=Bearer&state={state}")
|
||||||
return templates.TemplateResponse("login.html", {"request": {}, "redirect_url": redirect_url, "state": state})
|
return templates.TemplateResponse("login.html", {"request": {}, "redirect_url": redirect_url, "state": state})
|
||||||
|
@ -246,6 +261,7 @@ async def login(redirect_url: str):
|
||||||
async def login(session: Annotated[str | None, Cookie()] = None):
|
async def login(session: Annotated[str | None, Cookie()] = None):
|
||||||
|
|
||||||
if (session is not None):
|
if (session is not None):
|
||||||
|
session = prep_uuid(session)
|
||||||
username = await check_token(session)
|
username = await check_token(session)
|
||||||
if (username == ""):
|
if (username == ""):
|
||||||
return RedirectResponse(url="/login?redirect_url=/user")
|
return RedirectResponse(url="/login?redirect_url=/user")
|
||||||
|
|
Loading…
Reference in New Issue