Coverage for backend/tests/conftest.py: 92%

212 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-22 15:38 +0000

1""" 

2This module is designed to facilitate testing and database interaction in the application. It includes utility functions, 

3fixtures, and configurations for setting up and managing the SQLAlchemy database, test users, test tokens, and mock 

4clients for API testing. The provided components streamline testing by enabling convenient access to preconfigured 

5resources. 

6 

7This module is instrumental for efficiently executing unit and integration tests by establishing a robust test environment 

8and providing the necessary utilities for seamless interactions with the application's data and APIs. 

9""" 

10 

11import datetime as dt 

12from typing import Any, Generator 

13 

14import pytest 

15from fastapi import status 

16from requests import Response 

17from sqlalchemy import create_engine, orm 

18from starlette.testclient import TestClient 

19import os 

20 

21from app import models, database, schemas 

22from app.eis import models as eis_models 

23from app.main import app 

24from app.oauth2 import create_access_token 

25from tests.utils.create_data import ( 

26 create_users, 

27 create_companies, 

28 create_locations, 

29 create_aggregators, 

30 create_keywords, 

31 create_people, 

32 create_jobs, 

33 create_files, 

34 create_interviews, 

35 create_job_alert_emails, 

36 create_scraped_jobs, 

37 create_service_logs, 

38 create_job_application_updates, 

39 create_settings, 

40) 

41from tests.utils.seed_database import reset_database 

42 

# The tests use their own database: the application's URL with a "_test" suffix appended.
SQLALCHEMY_DATABASE_URL = database.SQLALCHEMY_DATABASE_URL + "_test"

# A single engine and session factory are shared by the fixtures in this module.
engine = create_engine(SQLALCHEMY_DATABASE_URL)
TestingSessionLocal = orm.sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

46 

47 

@pytest.fixture
def session() -> Generator[orm.Session, Any, None]:
    """Yield a database session bound to the test engine, one per test function.

    ``reset_database(engine)`` runs first, so each test starts from whatever state
    that helper establishes (presumably a freshly recreated schema - confirm in
    ``tests.utils.seed_database``). The session is always closed once the test is done.
    :yield: A session created by ``TestingSessionLocal``."""

    reset_database(engine)
    test_session = TestingSessionLocal()
    try:
        yield test_session
    finally:
        # Runs during fixture teardown, whether the test passed or failed.
        test_session.close()

62 

63 

@pytest.fixture
def client(session) -> Generator[TestClient, Any, None]:
    """Yield a ``TestClient`` whose ``database.get_db`` dependency uses the test session.

    The override is registered on ``app.dependency_overrides`` before the client is
    created and removed again when the test has finished.
    :param session: The test database session handed to the application.
    :yield: A ``TestClient`` wrapping ``app``."""

    def _provide_test_session() -> Generator[orm.Session, Any, None]:
        """Stand-in for ``database.get_db`` that yields the fixture's session."""
        try:
            yield session
        finally:
            # Mirrors a request-scoped dependency: the session is closed when the dependency exits.
            session.close()

    app.dependency_overrides[database.get_db] = _provide_test_session
    yield TestClient(app)
    # Teardown: drop the override so it does not leak into other tests.
    app.dependency_overrides.pop(database.get_db, None)

83 

84 

@pytest.fixture
def test_users(session) -> list[models.User]:
    """Return the users produced by ``create_users`` for the test session."""

    users = create_users(session)
    return users

90 

91 

@pytest.fixture
def tokens(test_users) -> list[str]:
    """Return one access token per test user, in the same order as ``test_users``."""

    # Each token carries only the user's id as its payload.
    payloads = ({"user_id": account.id} for account in test_users)
    return [create_access_token(payload) for payload in payloads]

97 

98 

@pytest.fixture
def authorised_clients(client: TestClient, tokens: list[str]) -> list[TestClient]:
    """Return one authenticated ``TestClient`` per token, in token order.

    Each client wraps the same application as ``client`` and sends the base client's
    headers plus an ``Authorization: Bearer <token>`` header."""

    def _with_bearer(token: str) -> TestClient:
        """Build a new client for the same app that authenticates with ``token``."""
        authed = TestClient(client.app)
        authed.headers = {**client.headers, "Authorization": f"Bearer {token}"}
        return authed

    return [_with_bearer(token) for token in tokens]

109 

110 

@pytest.fixture
def test_keywords(session, test_users) -> list[models.Keyword]:
    """Return the keywords produced by ``create_keywords`` for the test users."""

    keywords = create_keywords(session, test_users)
    return keywords

116 

117 

@pytest.fixture
def test_aggregators(session, test_users) -> list[models.Aggregator]:
    """Return the aggregators produced by ``create_aggregators`` for the test users."""

    aggregators = create_aggregators(session, test_users)
    return aggregators

123 

124 

@pytest.fixture
def test_locations(session, test_users) -> list[models.Location]:
    """Return the locations produced by ``create_locations`` for the test users."""

    locations = create_locations(session, test_users)
    return locations

130 

131 

@pytest.fixture
def test_companies(session, test_users) -> list[models.Company]:
    """Return the companies produced by ``create_companies`` for the test users."""

    companies = create_companies(session, test_users)
    return companies

137 

138 

@pytest.fixture
def test_persons(session, test_users, test_companies) -> list[models.Person]:
    """Return the people produced by ``create_people`` for the test users and companies."""

    people = create_people(session, test_users, test_companies)
    return people

144 

145 

@pytest.fixture
def test_files(session, test_users) -> list[models.File]:
    """Return the files produced by ``create_files`` for the test users."""

    files = create_files(session, test_users)
    return files

151 

152 

@pytest.fixture
def test_jobs(
    session,
    test_users,
    test_companies,
    test_locations,
    test_keywords,
    test_persons,
    test_aggregators,
    test_files,
) -> list[models.Job]:
    """Return the jobs produced by ``create_jobs`` from the other test-data fixtures.

    Note that ``create_jobs`` takes its arguments in a different order from this
    fixture's own parameter list."""

    jobs = create_jobs(
        session,
        test_keywords,
        test_persons,
        test_users,
        test_companies,
        test_locations,
        test_aggregators,
        test_files,
    )
    return jobs

162 

163 

@pytest.fixture
def test_interviews(session, test_users, test_jobs, test_locations, test_persons) -> list[models.Interview]:
    """Return the interviews produced by ``create_interviews`` from the other fixtures."""

    interviews = create_interviews(
        session,
        test_persons,
        test_users,
        test_locations,
        test_jobs,
    )
    return interviews

169 

170 

@pytest.fixture
def test_job_alert_emails(session, test_users, test_service_logs) -> list[eis_models.JobAlertEmail]:
    """Return the job alert emails produced by ``create_job_alert_emails``."""

    alert_emails = create_job_alert_emails(session, test_users, test_service_logs)
    return alert_emails

176 

177 

@pytest.fixture
def test_scraped_jobs(session, test_users, test_job_alert_emails) -> list[eis_models.ScrapedJob]:
    """Return the scraped jobs produced by ``create_scraped_jobs`` for the test alert emails."""

    scraped = create_scraped_jobs(session, test_job_alert_emails, test_users)
    return scraped

183 

184 

@pytest.fixture
def test_service_logs(session) -> list[eis_models.EisServiceLog]:
    """Return the service logs produced by ``create_service_logs``."""

    service_logs = create_service_logs(session)
    return service_logs

190 

191 

@pytest.fixture
def test_job_application_updates(session, test_users, test_jobs) -> list[models.JobApplicationUpdate]:
    """Return the job application updates produced by ``create_job_application_updates``."""

    updates = create_job_application_updates(session, test_users, test_jobs)
    return updates

197 

198 

@pytest.fixture
def test_settings(session) -> list[models.Setting]:
    """Return the settings produced by ``create_settings``."""

    settings = create_settings(session)
    return settings

204 

205 

def open_file(filepath: str) -> str:
    """Read a text file from the ``resources`` directory next to this module.

    :param filepath: Name (or relative path) of the file inside ``resources``.
    :return: The complete contents of the file.
    :raises OSError: If the file cannot be opened (e.g. ``FileNotFoundError``)."""

    base_dir = os.path.dirname(__file__)  # directory containing this module
    resource_path = os.path.join(base_dir, "resources", filepath)
    # Decode explicitly as UTF-8 so the result does not depend on the platform's locale default.
    with open(resource_path, "r", encoding="utf-8") as ofile:
        return ofile.read()

214 

215 

class CRUDTestBase:
    """Base class providing a reusable CRUD test suite for FastAPI routes.

    Subclasses configure the suite through class attributes:
        - endpoint: str - base URL path for the resource (e.g. "/aggregators")
        - schema: Pydantic model class for input validation (e.g. schemas.Aggregator)
        - out_schema: Pydantic model class used to parse response bodies (e.g. schemas.AggregatorOut)
        - test_data: str - name of the pytest fixture providing the list of test objects
        - update_data: dict - body sent by ``test_put_success``; must contain an "id" key
        - create_data: list[dict] - bodies posted one by one by ``test_post_success``
        - add_fixture: optional list of extra fixture names requested before every test"""

    endpoint: str = ""
    schema = None
    out_schema = None
    test_data: str = ""
    update_data: dict | None = None
    create_data: list[dict] | None = None
    add_fixture = None

    def check_output(
        self,
        test_data: list[schemas.BaseModel] | list[dict] | dict | schemas.BaseModel,
        response_data: list[dict] | dict,
    ):
        """Assert that response data matches the expected test data.

        Two lists are compared pairwise; any other input is compared key by key using
        the keys (or instance attributes) of ``test_data``.
        :param test_data: The expected data (object, dict, or list of either).
        :param response_data: The data returned by the API (dict or list of dicts).
        :raises AssertionError: If any compared value differs."""

        if isinstance(test_data, list) and isinstance(response_data, list):
            # Compare EVERY pair, not just the first one. zip() stops at the shorter
            # list, so surplus items on either side are intentionally not checked here.
            for expected_item, actual_item in zip(test_data, response_data):
                self.check_output(expected_item, actual_item)
            return None

        # Raw JSON objects are parsed through the output schema before comparison.
        if isinstance(response_data, dict):
            response_data = self.out_schema(**response_data)

        # The expected data decides which keys are compared.
        if isinstance(test_data, dict):
            items = test_data.items()
        else:
            items = vars(test_data).items()

        for key, value in items:
            # Skip private attributes (e.g. ORM instance state) and keys the response lacks.
            # NOTE(review): a stock pydantic BaseModel has no ``__contains__`` and iterates
            # as (name, value) pairs - confirm ``out_schema`` supports ``key in model``,
            # otherwise every key is skipped and nothing is compared.
            if key.startswith("_") or key not in response_data:
                continue

            response_value = getattr(response_data, key)
            if isinstance(value, (models.Base, list)):
                # Nested objects and collections are compared recursively.
                self.check_output(value, response_value)
            elif key == "date" and isinstance(value, str):
                self._assert_date_matches(value, response_value)
            else:
                assert value == response_value, (
                    f"Mismatch for {key!r}: expected {value!r}, got {response_value!r}"
                )

        return None

    @staticmethod
    def _assert_date_matches(expected: str, actual) -> None:
        """Compare an ISO-format date string from the test data with the response value."""

        if not isinstance(actual, dt.datetime):
            assert expected == actual, f"Mismatch for 'date': expected {expected!r}, got {actual!r}"
            return

        parsed = dt.datetime.fromisoformat(expected)
        # Aware and naive datetimes never compare equal, so bring the parsed value to
        # the same timezone state as the response value first.
        if actual.tzinfo is not None and parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=dt.timezone.utc)
        elif actual.tzinfo is None and parsed.tzinfo is not None:
            parsed = parsed.replace(tzinfo=None)
        assert parsed == actual, f"Mismatch for 'date': expected {parsed!r}, got {actual!r}"

    # ------------------------------------------------- HELPER METHODS -------------------------------------------------

    def get_all(self, client) -> Response:
        """Helper method to get all items from the endpoint."""

        return client.get(self.endpoint)

    def get_one(self, client, item_id) -> Response:
        """Helper method to get one item from the endpoint."""

        return client.get(f"{self.endpoint}/{item_id}")

    def post(self, client, data) -> Response:
        """Helper method to post a new item to the endpoint."""

        return client.post(self.endpoint, json=data)

    def put(self, client: TestClient, item_id: int, data) -> Response:
        """Helper method to update an existing item in the endpoint."""

        return client.put(f"{self.endpoint}/{item_id}", json=data)

    def delete(self, client, item_id) -> Response:
        """Helper method to delete an existing item from the endpoint."""

        return client.delete(f"{self.endpoint}/{item_id}")

    @pytest.fixture(autouse=True)
    def setup_method(self, request) -> None:
        """Autouse fixture that requests every fixture named in ``add_fixture`` before each test."""

        if isinstance(self.add_fixture, list):
            for fixture in self.add_fixture:
                request.getfixturevalue(fixture)

    # ------------------------------------------------------- GET ------------------------------------------------------

    def test_get_all_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can list the resource and the items match the test data."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.get_all(authorised_clients[0])
        assert response.status_code == status.HTTP_200_OK
        self.check_output(test_data, response.json())

    def test_get_all_unauthorized(
        self,
        client: TestClient,
    ) -> None:
        """Listing without credentials is rejected with 401."""

        response = self.get_all(client)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_get_one_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can fetch the first test item and it matches the test data."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(authorised_clients[0], test_data[0].id)
        assert response.status_code == status.HTTP_200_OK
        self.check_output(test_data[0], response.json())

    def test_get_one_unauthorized(
        self,
        client,
        request,
    ) -> None:
        """Fetching an item without credentials is rejected with 401."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(client, test_data[0].id)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_get_one_other_user(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The second user fetching the first test item is rejected with 403."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.get_one(authorised_clients[1], test_data[0].id)
        assert response.status_code == status.HTTP_403_FORBIDDEN

    def test_get_one_non_exist(
        self,
        authorised_clients,
    ) -> None:
        """Fetching an id that does not exist yields 404."""

        response = self.get_one(authorised_clients[0], 999999)
        assert response.status_code == status.HTTP_404_NOT_FOUND

    # ------------------------------------------------------ POST ------------------------------------------------------

    def test_post_success(
        self,
        authorised_clients,
    ) -> None:
        """
        Generic POST test using the class attribute ``create_data``.
        Subclasses should set create_data = [dict(...), ...]
        """

        for raw_data in self.create_data:
            # "id" and "owner_id" are stripped from the body before it is posted.
            payload = {key: value for key, value in raw_data.items() if key not in ("id", "owner_id")}
            response = self.post(authorised_clients[0], payload)
            assert response.status_code == status.HTTP_201_CREATED
            self.check_output(payload, response.json())

    def test_post_unauthorized(
        self,
        client,
    ) -> None:
        """Posting without credentials is rejected with 401."""

        response = self.post(client, {})
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    # ------------------------------------------------------- PUT ------------------------------------------------------

    def test_put_success(
        self,
        authorised_clients,
        request,
    ) -> None:
        """The first user can update the item identified by ``update_data["id"]``."""

        # Requested only for its side effect of creating the test data.
        request.getfixturevalue(self.test_data)
        # noinspection PyTypeChecker
        response = self.put(authorised_clients[0], self.update_data["id"], self.update_data)
        assert response.status_code == status.HTTP_200_OK
        self.check_output(self.update_data, response.json())

    def test_put_empty_body(self, authorised_clients, request) -> None:
        """Updating with an empty body is rejected with 400."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.put(authorised_clients[0], test_data[0].id, {})
        assert response.status_code == status.HTTP_400_BAD_REQUEST

    def test_put_non_exist(self, authorised_clients) -> None:
        """Updating an id that does not exist yields 404."""

        response = self.put(authorised_clients[0], 999999, {})
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_put_unauthorized(self, client, request) -> None:
        """Updating without credentials is rejected with 401."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.put(client, test_data[0].id, {"name": "Test"})
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_put_other_user(self, authorised_clients, request) -> None:
        """The second user updating the first test item is rejected with 403."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.put(authorised_clients[1], test_data[0].id, {"name": "Test"})
        assert response.status_code == status.HTTP_403_FORBIDDEN

    # ----------------------------------------------------- DELETE -----------------------------------------------------

    def test_delete_success(self, authorised_clients, request) -> None:
        """The first user can delete the first test item (204)."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(authorised_clients[0], test_data[0].id)
        assert response.status_code == status.HTTP_204_NO_CONTENT

    def test_delete_non_exist(self, authorised_clients) -> None:
        """Deleting an id that does not exist yields 404."""

        response = self.delete(authorised_clients[0], 999999)
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_delete_unauthorized(self, client, request) -> None:
        """Deleting without credentials is rejected with 401."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(client, test_data[0].id)
        assert response.status_code == status.HTTP_401_UNAUTHORIZED

    def test_delete_other_user(self, authorised_clients, request) -> None:
        """The second user deleting the first test item is rejected with 403."""

        test_data = request.getfixturevalue(self.test_data)
        response = self.delete(authorised_clients[1], test_data[0].id)
        assert response.status_code == status.HTTP_403_FORBIDDEN