270 lines
9.2 KiB
Python
270 lines
9.2 KiB
Python
from fastapi.security import OAuth2PasswordBearer
|
||
from fastapi import FastAPI, Cookie, Response, Form
|
||
from fastapi.templating import Jinja2Templates
|
||
from fastapi.responses import RedirectResponse, HTMLResponse
|
||
from datetime import timedelta, datetime
|
||
from contextlib import asynccontextmanager
|
||
from . import db
|
||
from . import email as email_sys
|
||
from . import reg
|
||
import hashlib
|
||
import uuid
|
||
from typing import Annotated
|
||
import re
|
||
import asyncio
|
||
import uvicorn
|
||
from .cfg import config
|
||
|
||
|
||
# Settings pulled from the "common" section of the project configuration.
_common = config["common"]

ALGORITHM = _common["algorithm"]
# Lifetimes in minutes; the handlers below pass them through float().
ACCESS_TOKEN_EXPIRE_MINUTES = _common["access_token_expire_minutes"]
ACCESS_EMAIL_EXPIRE_MINUTES = _common["access_email_expire_minutes"]
# Base URL prepended to the verification / reset links that get emailed.
ROOT = _common["root"]
# Seconds slept between sweeps in clean_sys().
CLEAN_TIMEOUT = _common["clean_timeout"]
# Shared secret required by /manager/init.
MANAGE_KEY = _common["manage_key"]
# Allowed /login redirect targets, compared without scheme/trailing slash.
REDIRECT_URL_WHITELIST = _common["redirect_url_whitelist"]
|
||
|
||
|
||
# Process-local, in-memory stores shared by the handlers below. They are
# plain dicts/lists, so everything in them is lost on restart.

# session token -> (username, expiry datetime); written by create_token()
tokens: dict[str, tuple[str, datetime]] = {}
# API key -> (username, expiry datetime); values are copied from `tokens`
apikeys: dict[str, tuple] = {}
# verification id -> (username, password hash or "" for a reset link,
#                     expiry datetime, email address)
emails: dict[str, tuple] = {}
# FIFO of (recipient address, link) pairs consumed by send_email()
email_send_lst: list[tuple] = []
|
||
|
||
|
||
async def clean_sys():
    """Evict expired entries from the in-memory stores, forever.

    Every CLEAN_TIMEOUT seconds the stores ``tokens``, ``apikeys`` and
    ``emails`` are swept in that order. Any entry whose expiry timestamp is
    already in the past is deleted.
    """
    while True:
        await asyncio.sleep(CLEAN_TIMEOUT)
        # (store, position of the expiry datetime inside each stored tuple)
        for store, expiry_pos in ((tokens, 1), (apikeys, 1), (emails, 2)):
            # Iterate over a snapshot of the keys because we delete as we go.
            for key in list(store):
                if store[key][expiry_pos] < datetime.now():
                    del store[key]
|
||
|
||
|
||
async def send_email():
    """Drain ``email_send_lst`` forever, mailing queued links one by one.

    Each queued ``(address, link)`` pair is handed to
    ``email_sys.sendemail`` in its own task, so a slow send does not block
    the queue. The loop waits 0.1 s between dispatches. While the queue is
    empty it re-checks every 5 s.
    """
    # The event loop keeps only weak references to tasks. Hold strong ones
    # until each send completes so an in-flight email cannot be
    # garbage-collected mid-execution.
    in_flight: set[asyncio.Task] = set()
    while True:
        if email_send_lst:
            address, link = email_send_lst.pop(0)
            task = asyncio.create_task(email_sys.sendemail(address, link))
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)
            await asyncio.sleep(0.1)
        else:
            await asyncio.sleep(5)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    On startup it connects the database and starts the expiry sweeper
    (``clean_sys``) and the email dispatcher (``send_email``). On shutdown
    it cancels both background tasks and waits for them to finish.
    """
    await db.connect_db()
    # Keep references: asyncio holds tasks only weakly, and we need the
    # handles to stop the workers when the app shuts down.
    background = [
        asyncio.create_task(clean_sys()),
        asyncio.create_task(send_email()),
    ]
    try:
        yield
    finally:
        for task in background:
            task.cancel()
        # return_exceptions=True absorbs the CancelledError from each task.
        await asyncio.gather(*background, return_exceptions=True)
|
||
|
||
# Page templates live in the "src" directory. The path is relative, so it
# is resolved against the process working directory.
templates = Jinja2Templates(directory="src")

# OAuth2 bearer scheme whose token endpoint is "token".
# NOTE(review): not referenced by any handler visible in this file.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# The ASGI application. `lifespan` runs the startup/shutdown work.
app = FastAPI(lifespan=lifespan)
|
||
|
||
|
||
# Compiled once at import. The whole string must be 8-40 characters with
# no newline, and must not come from a single character class. The three
# rejected classes are: ASCII letters only, digits only, and "other" only
# (punctuation, whitespace, non-ASCII). Whitespace counts as "other", so a
# password made only of spaces is rejected.
_PASSWD_PATTERN = re.compile(
    r'(?![a-zA-Z]+\Z)(?!\d+\Z)(?![^\da-zA-Z]+\Z).{8,40}')


def check_passwd(passwd: str):
    """Return 0 if *passwd* is acceptable, 1 if it is too weak.

    A password is acceptable when it is 8 to 40 characters long, contains no
    newline, and mixes at least two of: ASCII letters, digits, other
    characters.
    """
    # fullmatch (rather than match + `$`) ensures a trailing "\n" cannot
    # slip past the length/class rules.
    return 0 if _PASSWD_PATTERN.fullmatch(passwd) else 1
|
||
|
||
|
||
async def authenticate_user(username: str, password: str):
    """Check *password* against the stored hash for *username*.

    Returns False when ``db.get_user`` yields a falsy value (unknown user).
    Otherwise returns True or False depending on whether the SHA-256 hex
    digest of *password* equals the stored value.
    """
    import hmac

    hashed_password = await db.get_user(username)
    if not hashed_password:
        return False
    # NOTE(review): unsalted SHA-256 is a weak password hash. Moving to a
    # salted KDF (hashlib.scrypt / pbkdf2_hmac) requires migrating the
    # stored hashes, so it is not done here.
    candidate = hashlib.sha256(password.encode("utf-8")).hexdigest()
    # Constant-time comparison so response timing does not reveal how much
    # of the stored hash matched. The stored value is presumably the hex
    # string written at signup/password change; str() keeps a non-str value
    # comparing unequal instead of raising.
    return hmac.compare_digest(
        str(hashed_password).encode("utf-8"), candidate.encode("utf-8"))
|
||
|
||
|
||
async def create_token(username: str):
    """Mint and return a new session token for *username*.

    The token is a random hex string from ``uuid.uuid4()``. It is recorded
    in ``tokens`` with an expiry ACCESS_TOKEN_EXPIRE_MINUTES from now.
    """
    session_token = uuid.uuid4().hex
    lifetime = timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES))
    tokens[session_token] = (username, datetime.now() + lifetime)
    return session_token
|
||
|
||
|
||
async def check_token(tkn: str):
    """Return the username owning session token *tkn*, or "" if invalid.

    An unknown token yields "". An expired token is removed from ``tokens``
    and also yields "".
    """
    entry = tokens.get(tkn)
    if entry is None:
        return ""
    owner, expires_at = entry
    if expires_at >= datetime.now():
        return owner
    del tokens[tkn]
    return ""
|
||
|
||
|
||
async def check_apikey(tkn: str):
    """Return the username bound to API key *tkn*, or "" if unusable.

    Expired keys are deleted from ``apikeys`` during the lookup.
    """
    record = apikeys.get(tkn)
    if record is None:
        return ""
    expired = record[1] < datetime.now()
    if expired:
        del apikeys[tkn]
    return "" if expired else record[0]
|
||
|
||
|
||
@app.post("/api/login")
async def login_callback(response: Response, username: str = Form(), password: str = Form()):
    """Password login: set a session cookie and return a fresh API key.

    On success the ``session`` cookie is set and the response body is
    ``{"msg": "", "key": <api key>}``. The API key shares the session's
    username and expiry. On failure ``key`` is "" and ``msg`` says the
    username or password is wrong.
    """
    if not await authenticate_user(username, password):
        return {"msg": "用户名或密码错误", "key": ""}
    session_token = await create_token(username)
    api_key = uuid.uuid4().hex
    apikeys[api_key] = tokens[session_token]
    response.set_cookie("session", session_token)
    return {"msg": "", "key": api_key}
|
||
|
||
|
||
# Email address check, applied with re.fullmatch by the signup handler.
# Local part: alphanumeric runs joined by single '.', '_' or '-'.
# Domain: one alphanumeric/hyphen label followed by one or more alphabetic
# labels of length >= 2.
# The separator class is `[._-]` with the hyphen last, so it is a literal
# set. The old `[.-_]` was the range '.'..'_', which admitted '@', '<', '/',
# ... and excluded '-'. The TLD class no longer contains a stray '|'.
regex = re.compile(
    r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+')
|
||
|
||
|
||
@app.post("/api/signup")
async def login_callback(username: str = Form(), password: str = Form(), email: str = Form()):
    """Start a signup: validate input, stash the pending account, queue mail.

    Returns ``{"msg": ..., "code": 1}`` on any of these failures:
    - weak password (per check_passwd)
    - invalid email (per ``regex``)
    - username already taken

    On success the pending account (username, SHA-256 password hash, expiry,
    email) is stored in ``emails`` under a random id. A verification link to
    /api/checkemail is queued for send_email(). ``code`` is 0.

    NOTE(review): this shares the name ``login_callback`` with the
    /api/login handler. Both routes are registered, but only this function
    is reachable under that module-level name.
    """
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8", "code": 1}
    if not re.fullmatch(regex, email):
        return {"msg": "邮箱不合法", "code": 1}
    if await db.check_user(username):
        return {"msg": "用户名重复", "code": 1}

    expire_minutes = float(ACCESS_EMAIL_EXPIRE_MINUTES)
    verify_id = uuid.uuid4().hex
    emails[verify_id] = (
        username,
        hashlib.sha256(password.encode("utf-8")).hexdigest(),
        datetime.now() + timedelta(minutes=expire_minutes),
        email,
    )
    email_send_lst.append((email, ROOT + "/api/checkemail?uid=" + verify_id))
    # Report the configured link lifetime instead of a hard-coded "10"
    # (":g" renders 10.0 as "10").
    return {"msg": f"验证邮件已发送到邮箱,请在{expire_minutes:g}分钟内完成验证", "code": 0}
|
||
|
||
|
||
@app.get("/api/checkemail")
async def checkemail(uid: str):
    """Finish a signup from the emailed verification link *uid*.

    Renders checkemail.html with a status message in every case:
    - unknown id
    - expired link (the entry is removed)
    - the id belongs to a password-reset link rather than a signup
    - account created
    - username taken (``db.create_user`` returned non-zero)
    """
    def page(msg):
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": msg})

    entry = emails.get(uid)
    if entry is None:
        return page("不存在的注册id")
    if entry[2] < datetime.now():
        del emails[uid]
        return page("链接已过期")
    if entry[1] == "":
        # An empty hash slot marks a password-reset entry, not a signup.
        return page("不存在的注册id")

    # Consume the link before awaiting the database. Otherwise a second
    # concurrent request for the same uid could also reach create_user and
    # then raise KeyError on the final delete.
    # NOTE(review): if create_user raises, the link is now already spent.
    del emails[uid]
    username, password_hash, _expires_at, address = entry
    if await db.create_user(username, password_hash, address) == 0:
        return page("创建成功")
    return page("重复的用户名")
|
||
|
||
|
||
@app.get("/api/resetpasswd", response_class=HTMLResponse)
async def resetpasswd(uid: str, response: Response):
    """Log the user in from an emailed password-reset link *uid*.

    A valid, unexpired reset entry does the following:
    - creates a session and sets it as the ``session`` cookie
    - mints an API key for the same session
    - returns a page that immediately redirects to /user, where the
      password can be changed

    Unknown, expired, or signup-type ids render checkemail.html with an
    error message instead.
    """
    entry = emails.get(uid)
    if entry is None:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的验证id"})
    if entry[2] < datetime.now():
        del emails[uid]
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "链接已过期"})
    if entry[1] != "":
        # A non-empty hash slot marks a signup entry, not a reset link.
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})

    # Reset links are single-use. Drop the entry before awaiting so that a
    # leaked or re-opened link cannot mint further sessions until expiry.
    del emails[uid]
    session_token = await create_token(entry[0])
    api_key = uuid.uuid4().hex
    apikeys[api_key] = tokens[session_token]
    response.set_cookie("session", session_token)
    return '<html><head><meta http-equiv="refresh" content="0;url=/user"><title>正在跳转</title></head><body>正在跳转</body></html>'
|
||
|
||
|
||
@app.post("/api/send_resetpasswd")
async def resetpasswd(username: str = Form()):
    """Queue a password-reset email for *username*.

    Unknown usernames get ``code`` 1. For a known user:
    - a reset entry is stored in ``emails``, with "" in the password-hash
      slot
    - a link to /api/resetpasswd is queued for the user's stored address
    - ``code`` is 0
    """
    if not await db.check_user(username):
        return {"msg": "用户名不存在", "code": 1}

    address = await db.get_email(username)
    expire_minutes = float(ACCESS_EMAIL_EXPIRE_MINUTES)
    reset_id = uuid.uuid4().hex
    # "" in the hash slot is what distinguishes a reset entry from a signup.
    emails[reset_id] = (
        username,
        "",
        datetime.now() + timedelta(minutes=expire_minutes),
        address,
    )
    email_send_lst.append((address, ROOT + "/api/resetpasswd?uid=" + reset_id))
    # Report the configured link lifetime instead of a hard-coded "10".
    return {"msg": f"验证邮件已发送到邮箱,请在{expire_minutes:g}分钟内完成验证", "code": 0}
|
||
|
||
|
||
@app.get("/api/getinfo")
async def get_user_info(uid: str):
    """Return the username and email for API key *uid*.

    An unknown or expired key gives ``code`` 1 with empty ``data``.
    """
    username = await check_apikey(uid)
    if username == "":
        return {"code": 1, "msg": "token无效", "data": {}}
    address = await db.get_email(username)
    return {"code": 0, "msg": "", "data": {"username": username, "email": address}}
|
||
|
||
|
||
@app.post("/api/changepasswd")
async def changepasswd(password: str = Form(), session: Annotated[str | None, Cookie()] = None):
    """Change the logged-in user's password and end the current session.

    Responses:
    - weak password: an error message
    - missing or invalid ``session`` cookie: an error message
    - success: the new SHA-256 hash is stored, the session token is
      removed, and ``msg`` is ""

    NOTE(review): API keys and other sessions of the user stay valid.
    """
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"}

    username = await check_token(session) if session is not None else ""
    if username == "":
        return {"msg": "无效的登录token"}

    await db.update_passwd(username, hashlib.sha256(password.encode("utf-8")).hexdigest())
    # The token can vanish while update_passwd is awaited (expiry sweep or a
    # concurrent request), so remove it without assuming it still exists.
    tokens.pop(session, None)
    return {"msg": ""}
|
||
|
||
|
||
@app.get("/login")
async def login(state: str = "", client_id: str = "", redirect_url: str = "/user", session: Annotated[str | None, Cookie()] = None):
    """Login page and implicit-grant style redirect.

    *redirect_url* must be in REDIRECT_URL_WHITELIST, compared without
    scheme, fragment and trailing slash. With a valid ``session`` cookie,
    a fresh API key is minted and the browser is redirected to
    ``redirect_url#access_token=...&token_type=Bearer&state=...``. Otherwise
    login.html is rendered. *client_id* is accepted but unused.

    NOTE(review): if the whitelist is a plain string rather than a
    collection, ``in`` is a substring test; confirm the config type.
    """
    from urllib.parse import urlencode

    # Drop any existing fragment. The whitelist check ignores it, and it
    # would collide with the fragment carrying the token below.
    target = redirect_url.split("#")[0]
    normalized = target.replace("https://", "").replace("http://", "").rstrip("/")
    if normalized not in REDIRECT_URL_WHITELIST:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "无效的重定向URL"})

    if session is not None and await check_token(session) != "":
        api_key = uuid.uuid4().hex
        apikeys[api_key] = tokens[session]
        # Form-urlencode the fragment parameters so a client-supplied
        # `state` cannot inject extra parameters such as access_token.
        fragment = urlencode({"access_token": api_key, "token_type": "Bearer", "state": state})
        return RedirectResponse(url=f"{target}#{fragment}")

    return templates.TemplateResponse("login.html", {"request": {}, "redirect_url": redirect_url, "state": state})
|
||
|
||
|
||
@app.get("/signup")
async def login(redirect_url: str):
    """Render the signup form, passing *redirect_url* through to it."""
    context = {"request": {}, "redirect_url": redirect_url}
    return templates.TemplateResponse("signup.html", context)
|
||
|
||
|
||
@app.get("/user")
async def login(session: Annotated[str | None, Cookie()] = None):
    """Serve the account-management page to a logged-in user.

    A missing ``session`` cookie is treated exactly like an invalid one.
    Both are redirected to /login, which then returns here.
    """
    if session is None or await check_token(session) == "":
        return RedirectResponse(url="/login?redirect_url=/user")
    return templates.TemplateResponse("manage.html", {"request": {}})
|
||
|
||
|
||
@app.get("/resetpasswd")
async def resetpasswd(response: Response):
    """Render the "forgot password" form (``response`` is unused)."""
    page_context = {"request": {}}
    return templates.TemplateResponse("resetpasswd.html", page_context)
|
||
|
||
|
||
@app.get("/manager/init")
async def init(key: str):
    """Run ``db.create_db()`` when *key* equals MANAGE_KEY.

    Returns 1 for a wrong key, 0 after create_db completes.
    NOTE(review): the key travels in the query string and may end up in
    access logs.
    """
    import hmac

    # Constant-time comparison. The isinstance check preserves the old
    # `key != MANAGE_KEY` result when the configured key is not a str.
    if not (isinstance(MANAGE_KEY, str)
            and hmac.compare_digest(key.encode("utf-8"), MANAGE_KEY.encode("utf-8"))):
        return 1
    await db.create_db()
    return 0
|
||
# Hand the app, the public base URL and the shared API-key store to the
# `reg` module (presumably it registers extra routes on `app` -- confirm in
# reg.py).
reg.main(app, ROOT, apikeys)
|
||
|
||
|
||
def run():
    """Serve the app with uvicorn on all interfaces, port 8000 (blocking)."""
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
|