import asyncio
import hashlib
import hmac
import importlib
import os
import re
import sys
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Annotated
from urllib.parse import quote

import pygtrie
import uvicorn
from fastapi import Cookie, FastAPI, Form, Response
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.templating import Jinja2Templates

from . import db
from . import email as email_sys
from . import cfg


def load_cfg():
    """(Re)read the ``common`` config section into module-level globals.

    Called once at import time and again by ``reload_cfg()``.
    """
    global ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_EMAIL_EXPIRE_MINUTES
    global ROOT, CLEAN_TIMEOUT, MANAGE_KEY, REDIRECT_URL_WHITELIST
    common = cfg.config["common"]
    ALGORITHM = common["algorithm"]
    ACCESS_TOKEN_EXPIRE_MINUTES = common["access_token_expire_minutes"]
    ACCESS_EMAIL_EXPIRE_MINUTES = common["access_email_expire_minutes"]
    ROOT = common["root"]
    CLEAN_TIMEOUT = common["clean_timeout"]
    MANAGE_KEY = common["manage_key"]
    REDIRECT_URL_WHITELIST = common["redirect_url_whitelist"]


load_cfg()

# In-memory stores. Keys are tokens with "/" between every character (see
# prep_uuid), so pygtrie.StringTrie stores them one character per level.
#   tokens:  session token -> (username, expiry datetime)
#   apikeys: API key       -> (username, expiry datetime)
#   emails:  email link id -> (username, password hash or "" for a password
#            reset, expiry datetime, email address)
tokens = pygtrie.StringTrie()
apikeys = pygtrie.StringTrie()
emails = pygtrie.StringTrie()
# Outgoing mails as (address, link) tuples, drained by send_email().
email_send_lst = []

# Strong references to fire-and-forget tasks: the event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-flight.
_background_tasks: set[asyncio.Task] = set()


def _on_task_done(task: asyncio.Task):
    """Forget a finished background task and report it if it failed."""
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        sys.stderr.write(f"--> background task failed: {task.exception()!r}\n")


def _spawn(coro) -> asyncio.Task:
    """Schedule *coro* as a task that is kept alive until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_task_done)
    return task


def prep_uuid(uuid: str):
    """Turn a token into its trie key by putting "/" between its characters.

    Any "/" already present is dropped first, so the function is idempotent:
    a key handed to a client in its separated form maps to the same entry as
    its raw form.
    """
    return '/'.join(uuid.replace("/", ""))


def clean_uuid(uuid: str):
    """Inverse of prep_uuid: strip the "/" separators from a trie key."""
    return uuid.replace("/", "")


async def clean_sys():
    """Every CLEAN_TIMEOUT seconds, purge expired sessions, API keys and links."""
    while True:
        await asyncio.sleep(CLEAN_TIMEOUT)
        sys.stderr.write("==> clean\n")
        now = datetime.now()
        # (store, index of the expiry datetime inside the stored tuple)
        for store, expiry_index in ((tokens, 1), (apikeys, 1), (emails, 2)):
            # Snapshot the entries so deleting inside the loop is safe.
            for key, value in list(store.items()):
                if value[expiry_index] < now:
                    del store[key]


async def send_email():
    """Drain email_send_lst, starting one sending task per queued mail.

    Checks again after 0.1 s while mails are queued, after 5 s when idle.
    """
    while True:
        if email_send_lst:
            infos = email_send_lst.pop(0)
            _spawn(email_sys.sendemail(infos[0], infos[1]))
            await asyncio.sleep(0.1)
        else:
            await asyncio.sleep(5)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect the database and run the background workers while the app lives."""
    await db.connect_db()
    workers = [_spawn(clean_sys()), _spawn(send_email())]
    try:
        yield
    finally:
        # The workers loop forever; stop them explicitly on shutdown.
        for task in workers:
            task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)


app = FastAPI(lifespan=lifespan)
templates = Jinja2Templates(directory="src")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

_PASSWD_PATTERN = re.compile(r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{8,40}$')


def check_passwd(passwd: str):
    """Return 0 if *passwd* is acceptable, 1 if it is too weak.

    Acceptable means 8-40 characters that are not only ASCII letters, not only
    digits, and not only characters outside [0-9A-Za-z\\s].
    """
    return 0 if _PASSWD_PATTERN.match(passwd) else 1


def _hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password*, the form passed to the db."""
    # NOTE(security): unsalted SHA-256 is weak for password storage; it is kept
    # only because already-stored hashes use exactly this format.
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


async def authenticate_user(username: str, password: str):
    """Return True if *password* matches the hash stored for *username*."""
    expected = await db.get_user(username)
    if not expected or not isinstance(expected, str):
        return False
    # Constant-time comparison avoids leaking how many characters matched.
    return hmac.compare_digest(expected.encode("utf-8"),
                               _hash_password(password).encode("utf-8"))


async def create_token(username: str):
    """Create a session token for *username* and return its trie key."""
    tkn = prep_uuid(uuid.uuid4().hex)
    tokens[tkn] = (username,
                   datetime.now() + timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES)))
    return tkn


def _lookup_username(store, key: str) -> str:
    """Return the username stored under *key*, or "" if absent or expired.

    An expired entry is deleted from *store* on the way.
    """
    entry = store.get(key, None)
    if entry is None:
        return ""
    if entry[1] < datetime.now():
        del store[key]
        return ""
    return entry[0]


async def check_token(tkn: str):
    """Return the username for session key *tkn*, or "" if invalid/expired."""
    return _lookup_username(tokens, tkn)


async def check_apikey(tkn: str):
    """Return the username for API key *tkn*, or "" if invalid/expired."""
    return _lookup_username(apikeys, tkn)


def _email_expiry() -> datetime:
    """Expiry time for a newly issued email verification/reset link."""
    return datetime.now() + timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES))


def _email_sent_msg() -> str:
    """User message announcing a sent link, quoting the configured lifetime."""
    return f"验证邮件已发送到邮箱,请在{float(ACCESS_EMAIL_EXPIRE_MINUTES):g}分钟内完成验证"


def _message_page(msg: str):
    """Render checkemail.html showing a single message."""
    return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": msg})


@app.post("/api/login")
async def login_callback(response: Response, username: str = Form(), password: str = Form()):
    """Log in with username and password.

    On success, sets the "session" cookie and returns a new API key that shares
    the session's (username, expiry) entry. On failure, returns an error message
    and an empty key.
    """
    if not await authenticate_user(username, password):
        return {"msg": "用户名或密码错误", "key": ""}
    tokennow = await create_token(username)
    tkn = prep_uuid(uuid.uuid4().hex)
    apikeys[tkn] = tokens[tokennow]
    response.set_cookie("session", clean_uuid(tokennow))
    return {"msg": "", "key": tkn}


# Local part: alphanumeric runs optionally joined by ".", "_" or "-"; domain:
# one label followed by one or more alphabetic labels of length >= 2.
regex = re.compile(
    r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+')


@app.post("/api/signup")
async def signup_callback(username: str = Form(), password: str = Form(), email: str = Form()):
    """Validate a signup request and queue a verification email.

    The account is only created once the emailed /api/checkemail link is used.
    """
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8", "code": 1}
    if not regex.fullmatch(email):
        return {"msg": "邮箱不合法", "code": 1}
    if await db.check_user(username):
        return {"msg": "用户名重复", "code": 1}
    tkn = prep_uuid(uuid.uuid4().hex)
    emails[tkn] = (username, _hash_password(password), _email_expiry(), email)
    email_send_lst.append((email, ROOT + "/api/checkemail?uid=" + clean_uuid(tkn)))
    return {"msg": _email_sent_msg(), "code": 0}


@app.get("/api/checkemail")
async def checkemail(uid: str):
    """Consume a signup verification link and create the pending account."""
    uid = prep_uuid(uid)
    entry = emails.get(uid, None)
    if entry is None:
        return _message_page("不存在的注册id")
    if entry[2] < datetime.now():
        del emails[uid]
        return _message_page("链接已过期")
    if entry[1] == "":
        # A password-reset link, not a signup link.
        return _message_page("不存在的注册id")
    # Consume the link BEFORE awaiting the database, so a concurrent request
    # with the same uid cannot create the user twice or hit a KeyError on del.
    del emails[uid]
    if await db.create_user(entry[0], entry[1], entry[3]) == 0:
        return _message_page("创建成功")
    return _message_page("重复的用户名")


@app.get("/api/resetpasswd", response_class=HTMLResponse)
async def resetpasswd_verify(uid: str, response: Response):
    """Consume a password-reset link and log the user in via the session cookie."""
    uid = prep_uuid(uid)
    entry = emails.get(uid, None)
    if entry is None:
        return _message_page("不存在的验证id")
    if entry[2] < datetime.now():
        del emails[uid]
        return _message_page("链接已过期")
    if entry[1] != "":
        # A signup link, not a password-reset link.
        return _message_page("不存在的注册id")
    del emails[uid]
    tokennow = await create_token(entry[0])
    response.set_cookie("session", clean_uuid(tokennow))
    # NOTE(review): the body is plain text with no redirect mechanism;
    # presumably something client-side navigates on — confirm.
    return '正在跳转正在跳转'


@app.post("/api/send_resetpasswd")
async def send_resetpasswd(username: str = Form()):
    """Queue a password-reset link to the email address of *username*."""
    if not await db.check_user(username):
        return {"msg": "用户名不存在", "code": 1}
    email = await db.get_email(username)
    tkn = prep_uuid(uuid.uuid4().hex)
    emails[tkn] = (username, "", _email_expiry(), email)
    email_send_lst.append((email, ROOT + "/api/resetpasswd?uid=" + clean_uuid(tkn)))
    return {"msg": _email_sent_msg(), "code": 0}


@app.get("/api/getinfo")
async def get_user_info(uid: str):
    """Return the username and email for API key *uid*."""
    username = await check_apikey(prep_uuid(uid))
    if username == "":
        return {"code": 1, "msg": "token无效", "data": {}}
    return {"code": 0, "msg": "",
            "data": {"username": username, "email": await db.get_email(username)}}


@app.post("/api/changepasswd")
async def changepasswd(password: str = Form(),
                       session: Annotated[str | None, Cookie()] = None):
    """Change the password of the logged-in user and end the current session."""
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"}
    if session is None:
        return {"msg": "无效的登录token"}
    session = prep_uuid(session)
    username = await check_token(session)
    if username == "":
        return {"msg": "无效的登录token"}
    await db.update_passwd(username, _hash_password(password))
    # The session may have been removed while awaiting the database (cleanup
    # task or a concurrent request), so do not assume it is still present.
    tokens.pop(session, None)
    return {"msg": ""}


@app.get("/login")
async def login_page(state: str = "", client_id: str = "", redirect_url: str = "/user",
                     session: Annotated[str | None, Cookie()] = None):
    """Show the login page, or redirect with a fresh API key if already logged in.

    *redirect_url* must be in REDIRECT_URL_WHITELIST once its http(s) scheme,
    fragment and trailing "/" are stripped.
    """
    now_redirect_url = redirect_url.replace(
        "https://", "").replace("http://", "").split("#")[0].rstrip("/")
    if now_redirect_url not in REDIRECT_URL_WHITELIST:
        return _message_page("无效的重定向URL")
    if session is not None:
        session = prep_uuid(session)
        username = await check_token(session)
        if username != "":
            tkn = prep_uuid(uuid.uuid4().hex)
            apikeys[tkn] = tokens[session]
            # Drop any caller-supplied fragment so ours is the only one, and
            # percent-encode the caller-controlled state value.
            target = redirect_url.split("#")[0]
            return RedirectResponse(
                url=f"{target}#access_token={tkn}&token_type=Bearer"
                    f"&state={quote(state, safe='')}")
    return templates.TemplateResponse(
        "login.html", {"request": {}, "redirect_url": redirect_url, "state": state})


@app.get("/signup")
async def signup_page(redirect_url: str):
    """Render the signup page."""
    # NOTE(review): unlike /login, redirect_url is not checked against
    # REDIRECT_URL_WHITELIST here; confirm signup.html uses it safely.
    return templates.TemplateResponse(
        "signup.html", {"request": {}, "redirect_url": redirect_url})


@app.get("/user")
async def user_page(session: Annotated[str | None, Cookie()] = None):
    """Render the account page, redirecting to /login without a valid session."""
    if session is None or await check_token(prep_uuid(session)) == "":
        return RedirectResponse(url="/login?redirect_url=/user")
    return templates.TemplateResponse("manage.html", {"request": {}})


@app.get("/resetpasswd")
async def resetpasswd_page():
    """Render the page for requesting a password-reset email."""
    return templates.TemplateResponse("resetpasswd.html", {"request": {}})


def _check_manage_key(key: str) -> bool:
    """Compare *key* against MANAGE_KEY in constant time."""
    # A non-str configured key can never equal the str query parameter.
    if not isinstance(MANAGE_KEY, str):
        return False
    return hmac.compare_digest(key.encode("utf-8"), MANAGE_KEY.encode("utf-8"))


@app.get("/manager/init")
async def init(key: str):
    """Create the database schema. Returns 0 on success, 1 on a wrong key."""
    if not _check_manage_key(key):
        return 1
    await db.create_db()
    return 0


# Loaded plugin modules; each must provide main() and an async reload().
plugins = []


async def reload_cfg():
    """Reload the configuration here, in the email module and in every plugin."""
    cfg.reload()
    load_cfg()
    email_sys.load_cfg()
    for plugin in plugins:
        await plugin.reload()


@app.get("/manager/reload")
async def reload(key: str):
    """Reload the configuration. Returns 0 on success, 1 on a wrong key."""
    if not _check_manage_key(key):
        return 1
    sys.stderr.write("==> reload config\n")
    await reload_cfg()
    return 0


def _load_plugins():
    """Import each plugin/<name>.py, in name order, whose <name> is a config section.

    Each imported module is appended to ``plugins`` and its main() is called
    with (app, ROOT, apikeys).
    """
    for filename in sorted(os.listdir("plugin")):
        parts = filename.split(".")
        if len(parts) != 2 or parts[1] != "py":
            continue
        name = parts[0]
        if name not in cfg.config:
            sys.stderr.write("--> disable " + name + "\n")
            continue
        sys.stderr.write("==> load " + name + "\n")
        module = importlib.import_module("plugin." + name)
        plugins.append(module)
        module.main(app, ROOT, apikeys)


_load_plugins()


def run():
    """Serve the app with uvicorn on 0.0.0.0:8000."""
    uvicorn.run(app, host="0.0.0.0", port=8000)