66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
|
from sqlalchemy import Column, Integer, String
|
|
from dataclasses import dataclass
|
|
import time
|
|
|
|
# Declarative base shared by every ORM model below; DbConnector.__init__ passes
# its metadata to create_all() to create any missing tables.
Base = declarative_base()
|
|
|
|
|
|
class User(Base):
    """ORM row for one account in the ``users`` table.

    Columns:
        name: the user name, which is also the primary key.
        hash: stored credential hash, presumably bcrypt (the seeded admin
            value uses the ``$2b$`` prefix) — confirm with the auth code.
        role: role label such as ``"admin"`` or ``"user"``.
    """

    __tablename__ = "users"

    name = Column("name", String, primary_key=True)
    # The attribute name shadows the builtin hash() only inside the class namespace.
    hash = Column("hash", String)
    role = Column("role", String)
|
|
|
|
class Message(Base):
    """ORM row for one chat message in the ``messages`` table.

    Columns:
        id: integer primary key, auto-incremented by the database.
        room: identifier of the room the message was posted to.
        content: the message text.
        user: author's user name, stored as a plain string (no foreign key).
        timestamp: integer Unix time in seconds (``int(time.time())`` when
            written by ``DbConnector.add_msg_to_room``).
    """

    __tablename__ = "messages"

    id = Column("id", Integer, autoincrement=True, primary_key=True)
    room = Column("room", String)
    content = Column("content", String)
    user = Column("user", String)
    timestamp = Column("timestamp", Integer)
|
|
|
|
|
|
class DbConnector:
    """Data-access layer for users and chat messages over one SQLAlchemy session.

    On construction, every table registered on ``Base`` is created if it is
    missing, and a default ``admin`` account is seeded.

    NOTE(review): a single long-lived session is shared by all methods.
    SQLAlchemy sessions are not thread-safe, so presumably one instance must not
    be used from several threads at once — confirm against callers.
    """

    def __init__(self, db_url: str):
        """Connect to *db_url*, create missing tables and seed default rows.

        Args:
            db_url: SQLAlchemy database URL passed to ``create_engine``.
        """
        self.engine = create_engine(db_url)
        Base.metadata.create_all(self.engine)
        self.session = sessionmaker(bind=self.engine)()
        self._create_defaults()

    def _commit(self) -> None:
        """Commit the shared session, rolling it back if the commit fails.

        Without the rollback, a failed flush leaves the session in a
        pending-rollback state. Every later call on this connector would then
        fail as well.
        """
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def _create_defaults(self):
        """Seed the built-in ``admin`` account unless it already exists."""
        # SECURITY NOTE(review): this hard-codes an admin credential hash
        # ("$2b$" bcrypt format). Every fresh deployment gets the same admin
        # password; consider sourcing it from configuration instead.
        try:
            self.add_user(name="admin", hash="$2b$12$IcUr5w7pIFaXaGVFP5yVV.b.sIYjDbETR3l2PKgWO4nkrHU.1HmFa", role="admin")
        except ValueError:
            print("Default admin user already exists")

    def get_user(self, name: str) -> User | None:
        """Return the user whose name (primary key) is *name*, or ``None``."""
        return self.session.query(User).filter(User.name == name).first()

    def add_user(self, name: str, hash: str, role: str = "user"):
        """Insert a new user row and commit it.

        Args:
            name: user name (the primary key).
            hash: credential hash, stored verbatim in ``User.hash``.
            role: role label; defaults to ``"user"``.

        Raises:
            ValueError: if a user called *name* already exists. This also
                covers the case where a concurrent writer inserts it between
                the existence check and the commit.
        """
        # Imported lazily: only the duplicate-key path needs it.
        from sqlalchemy.exc import IntegrityError

        if self.get_user(name):
            raise ValueError("User already exists")
        self.session.add(User(name=name, hash=hash, role=role))
        try:
            self._commit()
        except IntegrityError as err:
            # The existence check above is racy. A duplicate that slips past it
            # shows up here as a primary-key violation; _commit already rolled back.
            raise ValueError("User already exists") from err

    def add_msg_to_room(self, room: str, msg: str, user: str):
        """Store *msg* from *user* in *room*, stamped with the current Unix time.

        Returns:
            The persisted ``Message``, refreshed so that its auto-incremented
            ``id`` is populated.
        """
        new_msg = Message(room=room, content=msg, user=user, timestamp=int(time.time()))
        self.session.add(new_msg)
        self._commit()
        self.session.refresh(new_msg)  # Refresh to get the auto-incremented ID
        return new_msg

    def get_messages_from_room(self, room: str, since: int | None = None) -> list[Message]:
        """Return the messages posted in *room*, oldest first.

        Args:
            room: room identifier to filter on.
            since: if given, keep only messages with ``timestamp >= since``.
                The bound is inclusive and in whole Unix seconds, as written
                by ``add_msg_to_room``.

        Returns:
            Matching messages, ordered by ``timestamp`` and then ``id``.
        """
        query = self.session.query(Message).filter(Message.room == room)
        if since is not None:
            query = query.filter(Message.timestamp >= since)
        # Without ORDER BY the database may return rows in any order. Ordering
        # by id breaks ties between messages stamped in the same second.
        return query.order_by(Message.timestamp, Message.id).all()
|