sacn_accout_system/server/main.py

270 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import hashlib
import hmac
import re
import uuid
from contextlib import asynccontextmanager
from datetime import timedelta, datetime
from typing import Annotated
from urllib.parse import quote

import uvicorn
from fastapi import FastAPI, Cookie, Response, Form
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.templating import Jinja2Templates

from . import db
from . import email as email_sys
from . import reg
from .cfg import config
# Settings read from the "common" section of the project config.
_common = config["common"]
ALGORITHM = _common["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = _common["access_token_expire_minutes"]
ACCESS_EMAIL_EXPIRE_MINUTES = _common["access_email_expire_minutes"]
ROOT = _common["root"]
CLEAN_TIMEOUT = _common["clean_timeout"]
MANAGE_KEY = _common["manage_key"]
REDIRECT_URL_WHITELIST = _common["redirect_url_whitelist"]

# Process-local state held in plain containers (not persisted).
# session token -> (username, expiry datetime), written by create_token.
tokens: dict = {}
# API key -> record shared with a session token when issued in this module;
# the dict is also handed to reg.main, which may add its own entries.
apikeys: dict = {}
# verification id -> (username, password hash or "" for reset links, expiry, email).
emails: dict = {}
# Outgoing mails as (address, link) pairs, drained by send_email.
email_send_lst: list = []
async def clean_sys():
    """Evict expired entries from the in-memory stores, forever.

    Every CLEAN_TIMEOUT seconds, removes records whose expiry has passed:
    the expiry lives at index 1 for `tokens` and `apikeys`, and at index 2
    for `emails`.
    """
    def sweep(store, expiry_index):
        # Snapshot the keys: deleting from a dict while iterating it raises.
        for key in list(store):
            if store[key][expiry_index] < datetime.now():
                del store[key]

    while True:
        await asyncio.sleep(CLEAN_TIMEOUT)
        sweep(tokens, 1)
        sweep(apikeys, 1)
        sweep(emails, 2)
async def send_email():
    """Drain `email_send_lst` forever, sending each mail in its own task.

    When the queue is non-empty, the oldest (address, link) entry is popped and
    `email_sys.sendemail(address, link)` is started as a task. The loop then pauses
    0.1 s before the next entry. When the queue is empty it sleeps 5 s.
    """
    # The event loop keeps only weak references to tasks. Hold strong
    # references until each send finishes so it cannot be garbage-collected
    # mid-execution; the done-callback drops the reference afterwards.
    in_flight = set()
    while True:
        if not email_send_lst:
            await asyncio.sleep(5)
            continue
        item = email_send_lst.pop(0)
        task = asyncio.create_task(email_sys.sendemail(item[0], item[1]))
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)
        await asyncio.sleep(0.1)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Connects the database, then runs `clean_sys` and `send_email` in the
    background for as long as the application is up. On shutdown the
    background tasks are cancelled and awaited.
    """
    await db.connect_db()
    # Keep references to the background tasks: the event loop holds only weak
    # references, so unreferenced tasks may be garbage-collected.
    background = [
        asyncio.create_task(clean_sys()),
        asyncio.create_task(send_email()),
    ]
    try:
        yield
    finally:
        for task in background:
            task.cancel()
        # return_exceptions=True collects the CancelledErrors instead of raising.
        await asyncio.gather(*background, return_exceptions=True)
# Jinja2 templates are loaded from "src" (a relative path; presumably
# resolved against the working directory the server is started from).
templates = Jinja2Templates(directory="src")
# OAuth2 bearer helper with tokenUrl "token"; not referenced by the handlers
# visible in this module.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# The application; `lifespan` connects the DB and starts the background loops.
app = FastAPI(lifespan=lifespan)
def check_passwd(passwd: str):
    """Return 1 if *passwd* is too weak, 0 if it is acceptable.

    A password is accepted when it:

    * is 8-40 characters long;
    * contains no newline;
    * uses at least two of these character classes: ASCII letters, digits,
      and other non-whitespace symbols.

    Whitespace is allowed but does not count as a class, so e.g. eight spaces
    or "abcdefg " are rejected.
    """
    if not 8 <= len(passwd) <= 40 or "\n" in passwd:
        return 1
    classes = (r"[a-zA-Z]", r"\d", r"[^\da-zA-Z\s]")
    present = sum(1 for pattern in classes if re.search(pattern, passwd))
    return 0 if present >= 2 else 1
async def authenticate_user(username: str, password: str):
    """Return True when *password* matches the stored hash for *username*.

    The stored value is compared against the SHA-256 hex digest of the UTF-8
    password, which is the form this module writes on signup and password
    change. Returns False when `db.get_user` yields a falsy value (presumably
    an unknown user).

    NOTE(review): the hash is unsalted SHA-256. Moving to a salted KDF needs a
    stored-hash migration, so it is not changed here.
    """
    stored_hash = await db.get_user(username)
    if not stored_hash:
        return False
    candidate = hashlib.sha256(password.encode("utf-8")).hexdigest()
    # Constant-time comparison, so response timing does not reveal how much of
    # the digest matched. str() keeps a non-str DB value from raising TypeError;
    # such a value can never equal a hex digest, just as with `==`.
    return hmac.compare_digest(
        str(stored_hash).encode("utf-8"), candidate.encode("utf-8"))
async def create_token(username: str):
    """Mint a random session token for *username* and register it in `tokens`.

    The record is (username, expiry), where the expiry is
    ACCESS_TOKEN_EXPIRE_MINUTES minutes from now. Returns the token string.
    """
    lifetime = timedelta(minutes=float(ACCESS_TOKEN_EXPIRE_MINUTES))
    token = uuid.uuid4().hex
    tokens[token] = (username, datetime.now() + lifetime)
    return token
async def check_token(tkn: str):
    """Map a session token to its username, or return "" if unknown or expired.

    An expired record is evicted from `tokens` on lookup.
    """
    record = tokens.get(tkn)
    if record is not None and not record[1] < datetime.now():
        return record[0]
    if record is not None:
        # Present but past its expiry: drop it so it cannot be reused.
        tokens.pop(tkn)
    return ""
async def check_apikey(tkn: str):
    """Map an API key to its owner, or return "" if unknown or expired.

    An expired record is evicted from `apikeys` on lookup.
    """
    record = apikeys.get(tkn, None)
    if record is None:
        return ""
    expires_at = record[1]
    if expires_at < datetime.now():
        apikeys.pop(tkn)
        return ""
    return record[0]
@app.post("/api/login")
async def login_callback(response: Response, username: str = Form(), password: str = Form()):
    """Handle the login form.

    On success, sets a ``session`` cookie and returns a fresh API key in
    ``key``. The key shares the session's (username, expiry) record.
    On failure, returns an error ``msg`` and an empty ``key``.
    """
    if not await authenticate_user(username, password):
        return {"msg": "用户名或密码错误", "key": ""}
    session_token = await create_token(username)
    api_key = uuid.uuid4().hex
    # Same record object as the session, so both expire at the same moment.
    apikeys[api_key] = tokens[session_token]
    response.set_cookie("session", session_token)
    return {"msg": "", "key": api_key}
# Email syntax check used by the signup handler (applied with re.fullmatch).
# Local-part separators are exactly '.', '_' and '-'. The '-' is placed last
# in the class so it is a literal; written as `[.-_]` it would be the range
# 0x2E-0x5F, which includes '@', '<', '>', digits and uppercase letters.
# Keeping separators disjoint from [A-Za-z0-9] also makes the repetition
# unambiguous, so non-matching input cannot trigger catastrophic backtracking.
# The TLD class is [A-Za-z]; a '|' inside a character class would be literal.
regex = re.compile(
    r'([A-Za-z0-9]+[._-])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Za-z]{2,})+')
@app.post("/api/signup")
async def login_callback(username: str = Form(), password: str = Form(), email: str = Form()):
    """Start a registration and mail a confirmation link.

    Checks, in order: password strength, email syntax, and username
    availability. If all pass, the pending account is stored in `emails` as
    (username, password hash, expiry, email) under a random id. A
    /api/checkemail link carrying that id is then queued on `email_send_lst`.

    NOTE: this function reuses the module-level name ``login_callback`` of the
    /api/login handler. Both routes are registered by their decorators, but
    afterwards the name refers to this function.
    """
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8", "code": 1}
    if not re.fullmatch(regex, email):
        return {"msg": "邮箱不合法", "code": 1}
    if await db.check_user(username):
        return {"msg": "用户名重复", "code": 1}
    tkn = uuid.uuid4().hex
    password_hash = hashlib.sha256(password.encode("utf-8")).hexdigest()
    expires_at = datetime.now() + timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES))
    emails[tkn] = (username, password_hash, expires_at, email)
    email_send_lst.append((email, ROOT + "/api/checkemail?uid=" + tkn))
    # Report the configured link lifetime rather than a hard-coded 10 minutes.
    # `:g` renders 10 / 10.0 / "10" all as "10".
    minutes = float(ACCESS_EMAIL_EXPIRE_MINUTES)
    return {"msg": f"验证邮件已发送到邮箱请在{minutes:g}分钟内完成验证", "code": 0}
@app.get("/api/checkemail")
async def checkemail(uid: str):
    """Confirm a signup link and create the account it describes.

    *uid* is the id mailed by the signup handler. Behaviour by case:

    * Unknown ids, and entries whose password hash is "" (password-reset
      links), are rejected.
    * Expired entries are removed.
    * A valid entry is removed from `emails` before the database call, so the
      link is single-use even under concurrent requests.

    Renders checkemail.html with a status message.
    """
    entry = emails.get(uid)
    if entry is None:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
    if entry[2] < datetime.now():
        del emails[uid]
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "链接已过期"})
    if entry[1] == "":
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
    # Claim the entry before awaiting. A concurrent request with the same uid
    # then fails the lookup above, instead of also reaching create_user and
    # raising KeyError on a second `del`.
    del emails[uid]
    try:
        status = await db.create_user(entry[0], entry[1], entry[3])
    except BaseException:
        # Restore the entry so a failed DB call does not burn the link.
        emails.setdefault(uid, entry)
        raise
    if status == 0:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "创建成功"})
    return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "重复的用户名"})
@app.get("/api/resetpasswd", response_class=HTMLResponse)
async def resetpasswd(uid: str, response: Response):
    """Consume a password-reset link: log the user in and redirect to /user.

    *uid* must be a live entry in `emails` with an empty password hash, which
    is the form queued by /api/send_resetpasswd. The entry is removed when
    used, so each reset link works only once. A ``session`` cookie is set on
    the injected *response*, and a meta-refresh page to /user is returned.
    """
    entry = emails.get(uid)
    if entry is None:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的验证id"})
    if entry[2] < datetime.now():
        del emails[uid]
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "链接已过期"})
    if entry[1] != "":
        # A non-empty hash marks a signup confirmation, not a reset link.
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "不存在的注册id"})
    # Single-use: drop the link before issuing a session, so the same email
    # link cannot be replayed for further logins until it expires.
    del emails[uid]
    session_token = await create_token(entry[0])
    response.set_cookie("session", session_token)
    return '<html><head><meta http-equiv="refresh" content="0;url=/user"><title>正在跳转</title></head><body>正在跳转</body></html>'
@app.post("/api/send_resetpasswd")
async def resetpasswd(username: str = Form()):
    """Queue a password-reset link for *username*.

    If `db.check_user` reports the user exists:

    * the user's email is fetched;
    * (username, "", expiry, email) is stored in `emails` under a random id;
    * a /api/resetpasswd link carrying that id is queued on `email_send_lst`.

    Otherwise an error is returned.
    """
    if not await db.check_user(username):
        return {"msg": "用户名不存在", "code": 1}
    email = await db.get_email(username)
    tkn = uuid.uuid4().hex
    expires_at = datetime.now() + timedelta(minutes=float(ACCESS_EMAIL_EXPIRE_MINUTES))
    # The empty password hash distinguishes reset links from signup entries.
    emails[tkn] = (username, "", expires_at, email)
    email_send_lst.append((email, ROOT + "/api/resetpasswd?uid=" + tkn))
    # Report the configured link lifetime rather than a hard-coded 10 minutes.
    minutes = float(ACCESS_EMAIL_EXPIRE_MINUTES)
    return {"msg": f"验证邮件已发送到邮箱请在{minutes:g}分钟内完成验证", "code": 0}
@app.get("/api/getinfo")
async def get_user_info(uid: str):
    """Return the username and email behind the API key *uid*.

    Responds with code 1 and empty data when the key is unknown or expired.
    """
    owner = await check_apikey(uid)
    if owner == "":
        return {"code": 1, "msg": "token无效", "data": {}}
    profile = {"username": owner, "email": await db.get_email(owner)}
    return {"code": 0, "msg": "", "data": profile}
@app.post("/api/changepasswd")
async def changepasswd(password: str = Form(), session: Annotated[str | None, Cookie()] = None):
    """Change the password of the user identified by the ``session`` cookie.

    The new password must pass `check_passwd`. On success, every session
    token belonging to the user is revoked, including the current one, so the
    user must log in again.
    """
    if check_passwd(password):
        return {"msg": "密码强度弱,请至少包含一个大小写字符与数字且长度>8"}
    username = "" if session is None else await check_token(session)
    if username == "":
        return {"msg": "无效的登录token"}
    await db.update_passwd(username, hashlib.sha256(
        password.encode("utf-8")).hexdigest())
    # Revoke all of this user's sessions so a leaked session does not outlive
    # the password change. pop(..., None) is used because entries (including
    # the current one) may have been removed while update_passwd was awaited,
    # e.g. by clean_sys or a concurrent request.
    stale = [key for key, record in tokens.items() if record[0] == username]
    stale.append(session)
    for key in stale:
        tokens.pop(key, None)
    return {"msg": ""}
@app.get("/login")
async def login(state: str = "", client_id: str = "", redirect_url: str = "/user", session: Annotated[str | None, Cookie()] = None):
    """Login page / implicit-grant style redirect.

    *redirect_url* must normalize to an entry of REDIRECT_URL_WHITELIST.
    If the ``session`` cookie is valid, a new API key sharing the session's
    record is issued. The browser is then redirected to *redirect_url* with
    ``access_token``, ``token_type`` and ``state`` in the fragment. Otherwise
    the login form is rendered. *client_id* is accepted but unused here.
    """
    # Normalize for the whitelist lookup: drop one leading http(s):// scheme,
    # any fragment, and trailing slashes. Only a *leading* scheme is removed.
    # Removing it anywhere would let "https://https://allowed/cb" normalize to
    # a whitelisted entry while actually redirecting to the host "https".
    normalized = redirect_url
    for scheme in ("https://", "http://"):
        if normalized.startswith(scheme):
            normalized = normalized[len(scheme):]
            break
    normalized = normalized.split("#")[0].rstrip("/")
    if normalized not in REDIRECT_URL_WHITELIST:
        return templates.TemplateResponse("checkemail.html", {"request": {}, "msg": "无效的重定向URL"})
    if session is not None:
        username = await check_token(session)
        if username != "":
            tkn = uuid.uuid4().hex
            apikeys[tkn] = tokens[session]
            # Percent-encode the client-supplied state. Otherwise characters
            # such as '&' or '#' could inject extra parameters (e.g. a forged
            # access_token) into the fragment the client parses.
            safe_state = quote(state, safe="")
            return RedirectResponse(url=redirect_url + f"#access_token={tkn}&token_type=Bearer&state={safe_state}")
    return templates.TemplateResponse("login.html", {"request": {}, "redirect_url": redirect_url, "state": state})
@app.get("/signup")
async def login(redirect_url: str):
    """Render the signup form, passing *redirect_url* through to the template."""
    context = {"request": {}, "redirect_url": redirect_url}
    return templates.TemplateResponse("signup.html", context)
@app.get("/user")
async def login(session: Annotated[str | None, Cookie()] = None):
    """Serve the account-management page to logged-in users.

    Visitors without a valid ``session`` cookie are redirected to
    ``/login?redirect_url=/user``. A missing cookie is treated the same as an
    invalid one.
    """
    if session is None or await check_token(session) == "":
        return RedirectResponse(url="/login?redirect_url=/user")
    return templates.TemplateResponse("manage.html", {"request": {}})
@app.get("/resetpasswd")
async def resetpasswd(response: Response):
    """Serve the "forgot password" page; *response* is accepted but unused."""
    page_context = {"request": {}}
    return templates.TemplateResponse("resetpasswd.html", page_context)
@app.get("/manager/init")
async def init(key: str):
    """Create the database schema when *key* matches MANAGE_KEY.

    Returns 0 after `db.create_db()` completes, or 1 if the key is rejected.
    """
    # Fail closed when no usable key is configured. An empty MANAGE_KEY would
    # otherwise be matched by a bare `?key=`. A non-str key could never equal
    # the str query parameter anyway.
    if not isinstance(MANAGE_KEY, str) or MANAGE_KEY == "":
        return 1
    # Constant-time comparison, so timing does not leak the key.
    if not hmac.compare_digest(key.encode("utf-8"), MANAGE_KEY.encode("utf-8")):
        return 1
    await db.create_db()
    return 0
# Hand the app, ROOT and the shared `apikeys` store to the `reg` module.
# Presumably it registers additional routes that issue or validate API keys;
# verify against reg.py.
reg.main(app, ROOT, apikeys)
def run(host: str = "0.0.0.0", port: int = 8000):
    """Serve the application with uvicorn (blocks until the server stops).

    Args:
        host: Interface to bind; the default binds all IPv4 interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)