Coverage for backend/tests/conftest.py: 92%
212 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 15:38 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-22 15:38 +0000
1"""
2This module is designed to facilitate testing and database interaction in the application. It includes utility functions,
3fixtures, and configurations for setting up and managing the SQLAlchemy database, test users, test tokens, and mock
4clients for API testing. The provided components streamline testing by enabling convenient access to preconfigured
5resources.
7This module is instrumental for efficiently executing unit and integration tests by establishing a robust test environment
8and providing the necessary utilities for seamless interactions with the application's data and APIs.
9"""
11import datetime as dt
12from typing import Any, Generator
14import pytest
15from fastapi import status
16from requests import Response
17from sqlalchemy import create_engine, orm
18from starlette.testclient import TestClient
19import os
21from app import models, database, schemas
22from app.eis import models as eis_models
23from app.main import app
24from app.oauth2 import create_access_token
25from tests.utils.create_data import (
26 create_users,
27 create_companies,
28 create_locations,
29 create_aggregators,
30 create_keywords,
31 create_people,
32 create_jobs,
33 create_files,
34 create_interviews,
35 create_job_alert_emails,
36 create_scraped_jobs,
37 create_service_logs,
38 create_job_application_updates,
39 create_settings,
40)
41from tests.utils.seed_database import reset_database
# Point the tests at a separate database: "_test" is appended to the application's URL
# (assumes the URL ends with the database name — confirm against app.database).
SQLALCHEMY_DATABASE_URL = database.SQLALCHEMY_DATABASE_URL + "_test"
engine = create_engine(SQLALCHEMY_DATABASE_URL)
# Session factory for the test engine; flushing and committing are left to explicit calls.
TestingSessionLocal = orm.sessionmaker(bind=engine, autoflush=False, autocommit=False)
@pytest.fixture
def session() -> Generator[orm.Session, Any, None]:
    """Provide a fresh SQLAlchemy session against the test database.

    The test database is reset with ``reset_database`` before the session is opened,
    and the session is closed once the requesting test has finished, even if it failed.

    :yield: A session created by ``TestingSessionLocal`` (bound to the ``_test`` engine)."""
    reset_database(engine)
    test_db = TestingSessionLocal()
    try:
        yield test_db
    finally:
        test_db.close()
@pytest.fixture
def client(session) -> Generator[TestClient, Any, None]:
    """Provide a TestClient whose database dependency yields the test session.

    ``database.get_db`` is overridden on the FastAPI app for the duration of the test,
    and the override is removed again during fixture teardown.

    :param session: The test database session handed to every request.
    :yield: A TestClient for the FastAPI application."""

    def _yield_test_session() -> Generator[orm.Session, Any, None]:
        """Stand-in for ``database.get_db`` that hands out the shared test session."""
        try:
            yield session
        finally:
            # NOTE(review): this closes the shared fixture session after every request;
            # confirm no test relies on objects from ``session`` staying attached afterwards.
            session.close()

    app.dependency_overrides[database.get_db] = _yield_test_session
    test_client = TestClient(app)
    yield test_client
    # Remove the override so it does not leak into tests that do not use this fixture.
    app.dependency_overrides.pop(database.get_db, None)
@pytest.fixture
def test_users(session) -> list[models.User]:
    """Create the test users with ``create_users`` and return them."""
    users = create_users(session)
    return users
@pytest.fixture
def tokens(test_users) -> list[str]:
    """Return one access token per test user, in the same order as ``test_users``.

    Each token is created from a ``{"user_id": <id>}`` payload."""
    return [create_access_token({"user_id": each_user.id}) for each_user in test_users]
@pytest.fixture
def authorised_clients(client: TestClient, tokens: list[str]) -> list[TestClient]:
    """Return one bearer-authenticated TestClient per token, in token order.

    Each client targets the same app as ``client`` and copies its headers, adding an
    ``Authorization: Bearer <token>`` header."""

    def _with_bearer(token: str) -> TestClient:
        """Build a new client for ``client.app`` that sends ``token`` as a bearer credential."""
        bearer_client = TestClient(client.app)
        bearer_client.headers = {**client.headers, "Authorization": f"Bearer {token}"}
        return bearer_client

    return [_with_bearer(token) for token in tokens]
@pytest.fixture
def test_keywords(session, test_users) -> list[models.Keyword]:
    """Create test keywords for the test users via ``create_keywords`` and return them."""
    keywords = create_keywords(session, test_users)
    return keywords
@pytest.fixture
def test_aggregators(session, test_users) -> list[models.Aggregator]:
    """Create test aggregators for the test users via ``create_aggregators`` and return them."""
    aggregators = create_aggregators(session, test_users)
    return aggregators
@pytest.fixture
def test_locations(session, test_users) -> list[models.Location]:
    """Create test locations for the test users via ``create_locations`` and return them."""
    locations = create_locations(session, test_users)
    return locations
@pytest.fixture
def test_companies(session, test_users) -> list[models.Company]:
    """Create test companies for the test users via ``create_companies`` and return them."""
    companies = create_companies(session, test_users)
    return companies
@pytest.fixture
def test_persons(session, test_users, test_companies) -> list[models.Person]:
    """Create test people from the user and company fixtures via ``create_people``."""
    people = create_people(session, test_users, test_companies)
    return people
@pytest.fixture
def test_files(session, test_users) -> list[models.File]:
    """Create test files (used by job applications) via ``create_files`` and return them."""
    files = create_files(session, test_users)
    return files
@pytest.fixture
def test_jobs(
    session, test_users, test_companies, test_locations, test_keywords, test_persons, test_aggregators, test_files
) -> list[models.Job]:
    """Create test jobs via ``create_jobs``.

    The jobs are built from the keyword, person, user, company, location, aggregator
    and file fixtures, passed in the positional order ``create_jobs`` expects."""
    jobs = create_jobs(
        session,
        test_keywords,
        test_persons,
        test_users,
        test_companies,
        test_locations,
        test_aggregators,
        test_files,
    )
    return jobs
@pytest.fixture
def test_interviews(session, test_users, test_jobs, test_locations, test_persons) -> list[models.Interview]:
    """Create test interviews from the person, user, location and job fixtures."""
    interviews = create_interviews(session, test_persons, test_users, test_locations, test_jobs)
    return interviews
@pytest.fixture
def test_job_alert_emails(session, test_users, test_service_logs) -> list[eis_models.JobAlertEmail]:
    """Create test job alert emails from the user and service-log fixtures."""
    alert_emails = create_job_alert_emails(session, test_users, test_service_logs)
    return alert_emails
@pytest.fixture
def test_scraped_jobs(session, test_users, test_job_alert_emails) -> list[eis_models.ScrapedJob]:
    """Create test scraped jobs from the job alert email and user fixtures."""
    scraped_jobs = create_scraped_jobs(session, test_job_alert_emails, test_users)
    return scraped_jobs
@pytest.fixture
def test_service_logs(session) -> list[eis_models.EisServiceLog]:
    """Create test email-ingestion service logs via ``create_service_logs`` and return them."""
    service_logs = create_service_logs(session)
    return service_logs
@pytest.fixture
def test_job_application_updates(session, test_users, test_jobs) -> list[models.JobApplicationUpdate]:
    """Create test job application updates from the user and job fixtures."""
    updates = create_job_application_updates(session, test_users, test_jobs)
    return updates
@pytest.fixture
def test_settings(session) -> list[models.Setting]:
    """Create test settings via ``create_settings`` and return them."""
    settings = create_settings(session)
    return settings
def open_file(filepath: str, encoding: str = "utf-8") -> str:
    """Read a text file from the ``resources`` directory next to this module.

    The encoding is explicit (UTF-8 by default) so the result does not depend on the
    platform's locale-dependent default text encoding.

    :param filepath: Path of the file, relative to the ``resources`` directory.
    :param encoding: Text encoding used to decode the file (default ``"utf-8"``).
    :return: The full contents of the file."""
    resources_dir = os.path.join(os.path.dirname(__file__), "resources")
    with open(os.path.join(resources_dir, filepath), "r", encoding=encoding) as resource:
        return resource.read()
class CRUDTestBase:
    """Base class for CRUD tests on FastAPI routes.

    Subclasses must override:
    - endpoint: str - base URL path for the resource (e.g. "/aggregators")
    - schema: Pydantic model class for input validation (e.g. schemas.Aggregator)
    - out_schema: Pydantic model class used to parse responses (e.g. schemas.AggregatorOut)
    - test_data: str - name of the pytest fixture providing the list of test objects

    Subclasses may also set:
    - create_data: list of payload dicts sent by ``test_post_success``
    - update_data: payload dict, including the target ``id``, sent by ``test_put_success``
    - add_fixture: names of extra fixtures requested before every test method"""

    endpoint: str = ""
    schema = None
    out_schema = None
    test_data: str = ""
    update_data: dict | None = None
    create_data: list[dict] | None = None
    add_fixture: list[str] | None = None

    def check_output(
        self,
        test_data: list[schemas.BaseModel] | list[dict] | dict | schemas.BaseModel,
        response_data: list[dict] | dict,
    ) -> None:
        """Assert that ``response_data`` agrees with ``test_data``.

        - Lists are compared element by element. ``zip`` stops at the shorter list.
          NOTE(review): this appears intended to let a fixture list that also holds other
          users' records be compared with one user's response — confirm.
        - A dict response is first parsed with ``out_schema``.
        - Every public key of the test data (dict keys, or the object's instance
          attributes) that is also a field of the response is then compared.
        - Nested ORM objects and lists are compared recursively.
        - ISO-8601 strings are parsed before being compared with datetime fields.

        :param test_data: Expected data: an ORM/Pydantic object, a dict, or a list of them.
        :param response_data: Decoded JSON response, or a nested part of a parsed response.
        :raises AssertionError: If any compared value differs."""
        if isinstance(test_data, list) or isinstance(response_data, list):
            assert isinstance(test_data, list) and isinstance(response_data, list), (
                f"Expected two lists, got {type(test_data).__name__} and {type(response_data).__name__}"
            )
            for expected_item, actual_item in zip(test_data, response_data):
                self.check_output(expected_item, actual_item)
            return None

        # Parse raw JSON objects so fields are read as attributes and coerced to their declared types.
        if isinstance(response_data, dict):
            response_data = self.out_schema(**response_data)

        if isinstance(test_data, dict):
            items = test_data.items()
        elif hasattr(test_data, "__dict__"):
            assert response_data is not None, f"Expected {test_data!r}, got None"
            items = vars(test_data).items()
        else:
            # Scalars (e.g. list elements that are plain ids or strings) have no fields to walk.
            assert test_data == response_data, f"Expected {test_data!r}, got {response_data!r}"
            return None

        for key, value in items:
            # Skip private attributes (e.g. SQLAlchemy's instance state) and keys the response lacks.
            if key.startswith("_") or not self._has_field(response_data, key):
                continue
            response_value = self._get_field(response_data, key)
            if isinstance(value, (models.Base, list)):
                self.check_output(value, response_value)
            elif isinstance(value, str) and isinstance(response_value, dt.datetime):
                # Payloads carry datetimes as ISO-8601 strings; compare them as datetimes.
                parsed_value = self._parse_datetime(value, response_value)
                assert parsed_value == response_value, (
                    f"Mismatch for {key!r}: expected {parsed_value!r}, got {response_value!r}"
                )
            else:
                assert value == response_value, (
                    f"Mismatch for {key!r}: expected {value!r}, got {response_value!r}"
                )
        return None

    @staticmethod
    def _has_field(container: object, key: str) -> bool:
        """Return True if ``key`` is a field of ``container`` (a dict or a model/object).

        ``key in model`` cannot be used on a Pydantic model. Iterating a model yields
        ``(name, value)`` pairs and it defines no ``__contains__``, so a plain ``in``
        test is always False."""
        if isinstance(container, dict):
            return key in container
        model_fields = getattr(type(container), "model_fields", None)  # Pydantic v2
        if model_fields is None:
            model_fields = getattr(type(container), "__fields__", None)  # Pydantic v1
        if isinstance(model_fields, dict):
            return key in model_fields
        return hasattr(container, key)

    @staticmethod
    def _get_field(container: object, key: str) -> object:
        """Return field ``key`` from a dict or an attribute-style object."""
        if isinstance(container, dict):
            return container[key]
        return getattr(container, key)

    @staticmethod
    def _parse_datetime(value: str, reference: dt.datetime) -> dt.datetime:
        """Parse ISO-8601 ``value`` so its naive/aware state matches ``reference``.

        - A naive string compared with an aware reference is assumed to be UTC.
        - An aware string compared with a naive reference has its tzinfo dropped;
          the wall-clock time is kept."""
        parsed = dt.datetime.fromisoformat(value)
        if reference.tzinfo is not None and parsed.tzinfo is None:
            return parsed.replace(tzinfo=dt.timezone.utc)
        if reference.tzinfo is None and parsed.tzinfo is not None:
            return parsed.replace(tzinfo=None)
        return parsed

    # ------------------------------------------------- HELPER METHODS -------------------------------------------------

    def get_all(self, client) -> Response:
        """Helper method to get all items from the endpoint."""

        return client.get(self.endpoint)

    def get_one(self, client, item_id) -> Response:
        """Helper method to get one item from the endpoint."""

        return client.get(f"{self.endpoint}/{item_id}")

    def post(self, client, data) -> Response:
        """Helper method to post a new item to the endpoint."""

        return client.post(self.endpoint, json=data)

    def put(self, client: TestClient, item_id: int, data) -> Response:
        """Helper method to update an existing item in the endpoint."""

        return client.put(f"{self.endpoint}/{item_id}", json=data)

    def delete(self, client, item_id) -> Response:
        """Helper method to delete an existing item from the endpoint."""

        return client.delete(f"{self.endpoint}/{item_id}")

    @pytest.fixture(autouse=True)
    def setup_method(self, request) -> None:
        """Request every fixture named in ``add_fixture`` before each test method.

        The fixture values are discarded; presumably they are requested for their
        side effect of seeding the database."""

        if isinstance(self.add_fixture, (list, tuple)):
            for fixture_name in self.add_fixture:
                request.getfixturevalue(fixture_name)

    # ------------------------------------------------------- GET ------------------------------------------------------

    def test_get_all_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can list the resource, and the items match the fixture data."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.get_all(authorised_clients[0])
        assert response.status_code == status.HTTP_200_OK
        self.check_output(test_data, response.json())

    def test_get_all_unauthorized(
        self,
        client: TestClient,
    ) -> None:
        """Listing without a token is rejected with 401."""
        response = self.get_all(client)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_get_one_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can fetch the first fixture item, and it matches the fixture data."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(authorised_clients[0], test_data[0].id)
        assert response.status_code == status.HTTP_200_OK
        self.check_output(test_data[0], response.json())

    def test_get_one_unauthorized(
        self,
        client,
        request,
    ) -> None:
        """Fetching an item without a token is rejected with 401."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(client, test_data[0].id)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_get_one_other_user(
        self,
        authorised_clients,
        request,
    ) -> None:
        """A second user fetching the first fixture item is rejected with 403."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(authorised_clients[1], test_data[0].id)
        assert response.status_code == status.HTTP_403_FORBIDDEN

    def test_get_one_non_exist(
        self,
        authorised_clients,
    ) -> None:
        """Fetching an unknown id returns 404."""
        response = self.get_one(authorised_clients[0], 999999)
        assert response.status_code == status.HTTP_404_NOT_FOUND

    # ------------------------------------------------------ POST ------------------------------------------------------

    def test_post_success(
        self,
        authorised_clients,
    ) -> None:
        """Generic POST test driven by the class attribute ``create_data``.

        Subclasses should set ``create_data = [dict(...), ...]``. ``id`` and ``owner_id``
        are stripped from each payload before posting; presumably the API assigns them."""

        for raw_payload in self.create_data:
            payload = {key: value for key, value in raw_payload.items() if key not in ("id", "owner_id")}
            response = self.post(authorised_clients[0], payload)
            assert response.status_code == status.HTTP_201_CREATED
            self.check_output(payload, response.json())

    def test_post_unauthorized(
        self,
        client,
    ) -> None:
        """Posting without a token is rejected with 401."""
        response = self.post(client, {})
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    # ------------------------------------------------------- PUT ------------------------------------------------------

    def test_put_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can update the item identified by ``update_data["id"]``."""
        request.getfixturevalue(self.test_data)
        # noinspection PyTypeChecker
        response = self.put(authorised_clients[0], self.update_data["id"], self.update_data)
        assert response.status_code == status.HTTP_200_OK
        self.check_output(self.update_data, response.json())

    def test_put_empty_body(self, authorised_clients, request) -> None:
        """Updating with an empty body is rejected with 400."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.put(authorised_clients[0], test_data[0].id, {})
        assert response.status_code == status.HTTP_400_BAD_REQUEST

    def test_put_non_exist(self, authorised_clients) -> None:
        """Updating an unknown id returns 404."""
        response = self.put(authorised_clients[0], 999999, {})
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_put_unauthorized(self, client, request) -> None:
        """Updating without a token is rejected with 401."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.put(client, test_data[0].id, {"name": "Test"})
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_put_other_user(self, authorised_clients, request) -> None:
        """A second user updating the first fixture item is rejected with 403."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.put(authorised_clients[1], test_data[0].id, {"name": "Test"})
        assert response.status_code == status.HTTP_403_FORBIDDEN

    # ----------------------------------------------------- DELETE -----------------------------------------------------

    def test_delete_success(self, authorised_clients, request) -> None:
        """The first user can delete the first fixture item (204)."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(authorised_clients[0], test_data[0].id)
        assert response.status_code == status.HTTP_204_NO_CONTENT

    def test_delete_non_exist(self, authorised_clients) -> None:
        """Deleting an unknown id returns 404."""
        response = self.delete(authorised_clients[0], 999999)
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_delete_unauthorized(self, client, request) -> None:
        """Deleting without a token is rejected with 401."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(client, test_data[0].id)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_delete_other_user(self, authorised_clients, request) -> None:
        """A second user deleting the first fixture item is rejected with 403."""
        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(authorised_clients[1], test_data[0].id)
        assert response.status_code == status.HTTP_403_FORBIDDEN