292 lines
12 KiB
Python
292 lines
12 KiB
Python
import unittest
|
|
import threading
|
|
import subprocess
|
|
import time
|
|
import os
|
|
import signal
|
|
import requests
|
|
import unittest
|
|
|
|
# Handle of the API server subprocess: assigned by setUpModule and stopped
# by tearDownModule. None until the server has been launched.
server_process = None

# Base URL every request in this module is built from.
API_URL = "http://localhost:7000/api"
# SQLite file the server is pointed at (through DATABASE_URL) for this run.
DB_PATH = "./data/TEST.sqlite"

# Login credentials (username -> password) the tests authenticate with.
users = dict(
    admin="admin",
    max="12345",
)
|
|
|
|
def setUpModule():
    """Start the API server in a subprocess and wait until it answers HTTP.

    The server is launched as ``<current interpreter> -m simple_chat_api``
    with ``DATABASE_URL`` pointing at ``DB_PATH`` and
    ``CREATE_OPTIONAL_DEFAULT_USER`` enabled. The process handle is stored
    in the module-level ``server_process`` so tearDownModule can stop it.

    Raises:
        RuntimeError: if the server process exits during startup, or does not
            answer any HTTP request within the startup timeout.
    """
    import contextlib
    import sys

    global server_process

    startup_timeout = 15.0  # seconds to wait for the first HTTP answer

    # Start from an empty database so results do not depend on a previous
    # run that aborted before tearDownModule could delete the file.
    with contextlib.suppress(FileNotFoundError):
        os.remove(DB_PATH)
    # SQLite does not create missing parent directories on its own.
    os.makedirs(os.path.dirname(DB_PATH) or ".", exist_ok=True)

    # Copy the current environment so PATH and friends reach the child.
    env = os.environ.copy()
    env.update({
        "DATABASE_URL": f"sqlite:///{DB_PATH}",
        "CREATE_OPTIONAL_DEFAULT_USER": "true",
    })

    # sys.executable runs the server under the same interpreter/virtualenv as
    # the tests. Output is discarded rather than piped: nothing reads the
    # pipes, and a full pipe buffer would block the server mid-test.
    server_process = subprocess.Popen(
        [sys.executable, "-m", "simple_chat_api"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        env=env,
    )

    # Poll instead of sleeping a fixed time: any HTTP response (even a 404)
    # means the server is accepting requests.
    deadline = time.monotonic() + startup_timeout
    while True:
        if server_process.poll() is not None:
            raise RuntimeError(
                f"API server exited during startup with code {server_process.returncode}"
            )
        try:
            requests.get(API_URL, timeout=1)
            return
        except requests.exceptions.RequestException:
            if time.monotonic() >= deadline:
                # tearDownModule is not run when setUpModule raises, so stop
                # the server here to avoid leaving an orphan process.
                server_process.terminate()
                server_process.wait()
                raise RuntimeError(
                    f"API server did not answer at {API_URL} within {startup_timeout} s"
                )
            time.sleep(0.2)
|
|
|
|
def tearDownModule():
    """Stop the API server started by setUpModule and delete the test database.

    The server receives SIGTERM first; if it has not exited within 10 seconds
    it is killed. A database file that was never created is not an error.
    """
    import contextlib

    global server_process

    if server_process is not None:
        server_process.send_signal(signal.SIGTERM)
        try:
            # communicate() (unlike wait()) drains any piped output, so the
            # child cannot stay blocked on a full pipe while we wait for it.
            server_process.communicate(timeout=10)
        except subprocess.TimeoutExpired:
            server_process.kill()
            server_process.communicate()
        server_process = None

    with contextlib.suppress(FileNotFoundError):
        os.remove(DB_PATH)
|
|
|
|
def build_session(user: str, password: str) -> "tuple[requests.Session, dict]":
    """Log in as *user* via ``POST /user/token`` and return the session.

    A ``requests.Session`` keeps any cookies the server sets; no
    Authorization header is added here, so authentication of later requests
    presumably relies on those cookies (confirm against the server).

    Args:
        user: Username to log in with.
        password: That user's password.

    Returns:
        ``(session, payload)``: the session used for the login and the decoded
        JSON body of the token response.

    Raises:
        ValueError: if the server does not answer the login with HTTP 200.
    """
    session = requests.Session()
    response = session.post(
        f"{API_URL}/user/token",
        json={"user": user, "password": password},
        timeout=10,  # never hang the whole test run on a stuck server
    )
    if response.status_code != 200:
        session.close()
        raise ValueError(f"Failed to get token for user {user}; {response.text}")
    # A tuple, not a list: callers unpack it as `session, data = ...`.
    return session, response.json()
|
|
|
|
|
|
class TestServer(unittest.TestCase):
    """Smoke test that the server launched by setUpModule is reachable."""

    def test_online(self):
        """Check that the API base URL answers with 404."""
        status = requests.get(API_URL).status_code
        # A live server is expected to answer 404 on the bare API root; if
        # nothing is listening, requests raises before this assertion.
        self.assertEqual(status, 404, "Server is not running or not reachable")
|
|
|
|
|
|
class TestAuthEndpoints(unittest.TestCase):
    """Exercise the /user endpoints as the admin account and as a regular user.

    unittest runs test methods in alphabetical order, so the numeric prefixes
    fix the sequence: test_1_get_token logs in every account in ``users`` and
    stores the sessions in the class-level ``userSessions`` dict, which the
    later tests reuse.
    """

    # username -> logged-in requests.Session. A class attribute because
    # unittest creates a fresh TestCase instance for every test method.
    userSessions = {}

    def _require_sessions(self):
        """Skip the calling test if test_1_get_token stored no sessions."""
        if not self.userSessions:
            self.skipTest("No user sessions available. Run test_1_get_token first.")

    def _recreate_user(self, user):
        """Re-add the non-admin account *user* with its password from ``users``.

        Registered as a cleanup by test_6_delete_user so that an account
        deleted there still exists for TestMessage, whose setUpClass logs in
        as every entry of ``users`` and runs after this class.
        """
        admin_session = self.userSessions.get("admin")
        if admin_session is None:
            return
        # If the delete did not happen, this is expected to fail as a duplicate
        # (HTTP 400, see test_4_create_user), which is harmless here.
        admin_session.post(f"{API_URL}/user/add", json={
            "new_user": user,
            "new_password": users[user],
            "new_admin": False,
        })

    def test_1_get_token(self):
        """Test the /token endpoint"""
        for user, password in users.items():
            with self.subTest(user=user):
                try:
                    session, data = build_session(user, password)
                except ValueError as e:
                    self.fail(str(e))
                self.userSessions[user] = session

                expected_sub = {
                    "user": user,
                    "role": "admin" if user == "admin" else "user",
                }
                self.assertEqual(data["sub"], expected_sub, f"Token content mismatch for user {user}")

    def test_2_get_user(self):
        """Every user can read their own profile from GET /user."""
        self._require_sessions()
        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                response = session.get(f"{API_URL}/user")
                self.assertEqual(response.status_code, 200, f"Failed to get user info for {user}; {response.text}")
                expected = {"name": user}
                self.assertEqual(response.json(), expected, f"User info mismatch for {user}")

    def test_3_get_all(self):
        """Only the admin may list all users via GET /user/getAll."""
        self._require_sessions()
        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                response = session.get(f"{API_URL}/user/getAll")
                if user == "admin":
                    self.assertEqual(response.status_code, 200, f"Failed to get all users for admin; {response.text}")
                    data = response.json()
                    expected = [
                        {"name": "admin", "role": "admin"},
                        {"name": "max", "role": "user"},
                    ]
                    self.assertTrue(all(item in data for item in expected), "User list mismatch for admin")
                else:
                    self.assertEqual(response.status_code, 401, f"Non-admin user {user} should not access /getAll; {response.text}")

    def test_4_create_user(self):
        """Admin can create users (once, with valid names); others cannot."""
        self._require_sessions()

        new_users = [
            {"new_user": "newuser", "new_password": "newpass", "new_admin": False},
            {"new_user": "admin2", "new_password": "adminpass", "new_admin": True},
            {"new_user": "invalid/user", "new_password": "pass", "new_admin": False},
        ]

        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                if user == "admin":
                    for new_user in new_users:
                        response = session.post(f"{API_URL}/user/add", json=new_user)
                        if "invalid" in new_user["new_user"]:
                            self.assertEqual(response.status_code, 400, f"Creating user with invalid name should fail; {response.text}")
                            continue
                        self.assertEqual(response.status_code, 200, f"Admin failed to create new user; {response.text}")
                        self.assertIn("created successfully", response.json()["message"], "Created user info mismatch")

                        # Creating the same account a second time must fail.
                        response = session.post(f"{API_URL}/user/add", json=new_user)
                        self.assertEqual(response.status_code, 400, f"Creating duplicate user should fail; {response.text}")
                else:
                    # Issue a real request as the non-admin user, with a fresh,
                    # valid-looking name so a rejection is about permissions
                    # rather than a duplicate or invalid username.
                    response = session.post(f"{API_URL}/user/add", json={
                        "new_user": f"{user}child",
                        "new_password": "pass",
                        "new_admin": False,
                    })
                    # NOTE(review): the exact denial code is unconfirmed (this
                    # test originally expected 400; getAll/delete expect 401).
                    self.assertIn(response.status_code, (400, 401, 403), f"Non-admin user {user} should not create users; {response.text}")

    def test_5_change_password(self):
        """Every user can change their password and change it back."""
        self._require_sessions()
        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                payload = {
                    "old_password": users[user],
                    "new_password": "newpass",
                }
                response = session.post(f"{API_URL}/user/changePassword", json=payload)
                self.assertEqual(response.status_code, 200, f"Failed to change password; {response.text}")

                # Restore the original password so later logins still work.
                payload = {
                    "new_password": users[user],
                    "old_password": "newpass",
                }
                response = session.post(f"{API_URL}/user/changePassword", json=payload)
                self.assertEqual(response.status_code, 200, f"Failed to change passwords back; {response.text}")

    def test_6_delete_user(self):
        """Admin can delete others; non-admins can delete only themselves."""
        self._require_sessions()
        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                if user == "admin":
                    response = session.post(f"{API_URL}/user/delete/kim")
                    self.assertEqual(response.status_code, 200, f"Admin failed to delete user kim; {response.text}")
                    self.assertIn("deleted successfully", response.json()["message"], "Deleted user info mismatch")

                    response = session.post(f"{API_URL}/user/delete/nonexistent")
                    self.assertEqual(response.status_code, 400, f"Deleting non-existent user should fail; {response.text}")
                else:
                    response = session.post(f"{API_URL}/user/delete/ina")
                    self.assertEqual(response.status_code, 401, f"Non-admin user {user} should not delete users; {response.text}")

                    # The account is needed again by later test classes, so
                    # restore it once this test finishes.
                    self.addCleanup(self._recreate_user, user)
                    response = session.post(f"{API_URL}/user/delete/{user}")
                    self.assertEqual(response.status_code, 200, f"Non-admin user {user} should be able to delete self; {response.text}")
|
|
|
|
|
|
|
|
class TestMessage(unittest.TestCase):
    """Post to and read from the "general" channel as every test user."""

    # {"message_id": ..., "timestamp": ...} of the most recent post made by
    # test_1_post_message; None until that test has posted something.
    last_msg = None
    # username -> logged-in requests.Session, filled in setUpClass.
    userSessions = {}
    # Message bodies to post, paired with users in the iteration order of
    # `users` (admin first, then max).
    messages = [
        {"content": "Admin here, managing things."},
        {"content": "Hello, this is max!"},
    ]

    @classmethod
    def setUpClass(cls):
        """Open one logged-in session per account in ``users``."""
        for user, password in users.items():
            cls.userSessions[user], _ = build_session(user, password)

    def test_1_post_message(self):
        """Each user posts one message; remember the last post for later tests."""
        if not self.userSessions:
            self.skipTest("No user sessions available. Check TestMessage.setUpClass.")

        # zip() pairs users with messages without manual indexing.
        for (user, session), message in zip(self.userSessions.items(), self.messages):
            with self.subTest(user=user):
                response = session.post(f"{API_URL}/messages/general", json=message)
                self.assertEqual(
                    response.status_code,
                    200,
                    f"User {user} failed to post message; {response.text}"
                )

                # The response is treated as a dict of message_id -> message,
                # with the newly posted message as its last entry.
                msg_id, data = list(response.json().items())[-1]
                TestMessage.last_msg = {
                    "message_id": msg_id,
                    "timestamp": data["timestamp"],
                }

                self.assertIn(
                    message["content"],
                    data["content"],
                    f"Posted message response missing content for user {user}"
                )
                time.sleep(1)  # Ensure different timestamps

    def test_2_get_messages_since(self):
        """Fetching with ``since`` returns the last posted message first."""
        if not self.userSessions or self.last_msg is None:
            self.skipTest("No user sessions available or no messages posted. Run test_1_post_message first.")

        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                # Pass `since` via params so requests URL-encodes it; a raw
                # timestamp containing e.g. '+' would otherwise be mangled.
                response = session.get(
                    f"{API_URL}/messages/general",
                    params={"since": self.last_msg["timestamp"]},
                )
                self.assertEqual(
                    response.status_code,
                    200,
                    f"User {user} failed to get messages since timestamp; {response.text}"
                )

                received = response.json()
                self.assertTrue(received, f"User {user} received no messages since timestamp")
                first_key = next(iter(received))

                # Compare message ids: the message posted exactly at `since`
                # is expected to be included and to come first.
                self.assertEqual(
                    first_key,
                    self.last_msg["message_id"],
                    f"User {user} did not receive the expected message"
                )

    def test_3_get_all_messages(self):
        """Fetching without ``since`` returns exactly the messages posted above."""
        if not self.userSessions or self.last_msg is None:
            self.skipTest("No user sessions available or no messages posted. Run test_1_post_message first.")

        expected_contents = [msg["content"] for msg in self.messages]
        for user, session in self.userSessions.items():
            with self.subTest(user=user):
                response = session.get(f"{API_URL}/messages/general")
                self.assertEqual(
                    response.status_code,
                    200,
                    f"User {user} failed to get all messages; {response.text}"
                )

                messages_data = response.json()
                message_contents = [msg["content"] for msg in messages_data.values()]
                for content in expected_contents:
                    self.assertIn(content, message_contents, f"Message with content '{content}' not found")

                # Relies on the channel starting empty (fresh test database).
                self.assertEqual(len(messages_data), len(self.messages), "Number of messages doesn't match expected")
|
|
|