Coverage for backend / app / routers / utility.py: 88%
225 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-17 21:34 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-17 21:34 +0000
1"""CRUD router generator for data table operations.
3Provides a factory function to generate FastAPI routers with standard CRUD endpoints,
4including user ownership validation, query filtering, and many-to-many relationship handling."""
6from typing import Any, Callable
8from fastapi import APIRouter, Depends, HTTPException
9from sqlalchemy.exc import IntegrityError
10from sqlalchemy.orm import Query
11from sqlalchemy.orm import Session
12from starlette import status
13from starlette.requests import Request
15from app import database, models
16from app.core import oauth2
# Generic 403 error for requests the current user is not permitted to make.
# NOTE(review): raising this one shared instance repeatedly makes Python prepend each
# raise's frames to its existing __traceback__, so the object keeps growing and keeps old
# frames alive; prefer raising a fresh HTTPException built from its status_code/detail.
NOT_ALLOWED_EXCEPTION = HTTPException(
    detail="Not authorised to perform requested action",
    status_code=status.HTTP_403_FORBIDDEN,
)
24def filter_out_non_owned(
25 entry: Any,
26 current_user_id: int,
27 processed_objects: set = None,
28) -> Any:
29 """Recursively filter out related objects that don't belong to the current user.
30 :param entry: The SQLAlchemy model instance to filter
31 :param current_user_id: The ID of the current user
32 :param processed_objects: Set to track processed objects (prevents infinite recursion)
33 :return: The filtered SQLAlchemy model instance"""
35 if processed_objects is None:
36 processed_objects = set()
38 # Avoid infinite recursion
39 obj_id = id(entry)
40 if obj_id in processed_objects:
41 return entry
42 processed_objects.add(obj_id)
44 # Get the SQLAlchemy mapper for this object
45 if not hasattr(entry, "__mapper__"):
46 return entry
48 mapper = entry.__mapper__
50 # Iterate through all relationships
51 for relationship_prop in mapper.relationships:
52 attr_name = relationship_prop.key
53 related_value = getattr(entry, attr_name, None)
55 if related_value is None:
56 continue
58 # Handle list relationships (one-to-many, many-to-many)
59 if isinstance(related_value, list):
60 filtered_list = []
61 for item in related_value:
62 # Check if item has owner_id and if it matches current user
63 if hasattr(item, "owner_id"):
64 if item.owner_id == current_user_id:
65 # Recursively filter this item too
66 filtered_item = filter_out_non_owned(item, current_user_id, processed_objects)
67 filtered_list.append(filtered_item)
68 else:
69 # Keep items without owner_id (like system data)
70 filtered_item = filter_out_non_owned(item, current_user_id, processed_objects)
71 filtered_list.append(filtered_item)
73 # Replace the relationship with filtered list
74 setattr(entry, attr_name, filtered_list)
76 # Handle single relationships (many-to-one, one-to-one)
77 else:
78 if hasattr(related_value, "owner_id"):
79 if related_value.owner_id != current_user_id:
80 # Set to None if not owned by current user
81 setattr(entry, attr_name, None)
82 else:
83 # Recursively filter the related object
84 filtered_related = filter_out_non_owned(related_value, current_user_id, processed_objects)
85 setattr(entry, attr_name, filtered_related)
86 else:
87 # Keep and recursively filter items without owner_id
88 filtered_related = filter_out_non_owned(related_value, current_user_id, processed_objects)
89 setattr(entry, attr_name, filtered_related)
91 return entry
def filter_query(
    query: "Query",
    table_model,
    filter_params: dict,
) -> "Query":
    """Apply equality / IN / IS NULL filters to a SQLAlchemy query from request parameters.

    Only parameters that name a mapped column attribute of ``table_model`` are applied.
    Anything else is ignored: unknown names, relationships, methods and other class attributes
    (which would otherwise produce invalid or meaningless predicates from user input).
    :param query: The SQLAlchemy query object.
    :param table_model: The SQLAlchemy model class representing the table.
    :param filter_params: Mapping of parameter name to a single value or a list of values
        (e.g. from the request query string). The string "null" (any case) means SQL NULL.
    :return: The filtered query object."""

    mapper = getattr(table_model, "__mapper__", None)

    def is_filterable(name: str) -> bool:
        """Return True if ``name`` is a column attribute that may be filtered on."""
        if mapper is not None:
            return name in mapper.column_attrs
        return hasattr(table_model, name)

    def convert_value(value, column) -> Any:
        """Convert a raw parameter value to the column's Python type where possible.

        Values that cannot be converted are returned unchanged."""
        if isinstance(value, str) and value.lower() == "null":
            return None

        try:
            # TypeEngine.python_type raises NotImplementedError for types with no Python equivalent.
            python_type = column.type.python_type
        except (AttributeError, NotImplementedError):
            return value

        try:
            if python_type is int:
                return int(value)
            if python_type is float:
                return float(value)
            if python_type is bool:
                return value.lower() in ("true", "1", "yes", "on")
        except (ValueError, TypeError, AttributeError):
            return value
        return value

    for param_name, param_value in filter_params.items():
        if not is_filterable(param_name):
            continue

        col = getattr(table_model, param_name)

        if isinstance(param_value, list):
            # Repeated query parameter -> IN (...)
            query = query.filter(col.in_([convert_value(val, col) for val in param_value]))
        else:
            converted_value = convert_value(param_value, col)
            query = query.filter(col.is_(None) if converted_value is None else col == converted_value)

    return query
def assert_admin(user: models.User) -> None:
    """Raise a 403 error unless the user is an admin.

    :param user: The user to check.
    :raises HTTPException: 403, with the same status and detail as NOT_ALLOWED_EXCEPTION,
        if ``user.is_admin`` is falsy."""

    if not user.is_admin:
        # Raise a fresh instance: re-raising the shared module-level object would prepend
        # this call's frames to its existing __traceback__ on every failed check.
        raise HTTPException(
            status_code=NOT_ALLOWED_EXCEPTION.status_code,
            detail=NOT_ALLOWED_EXCEPTION.detail,
        )
# Substrings identifying a unique-constraint violation in the database driver's error text:
# PostgreSQL and SQLite respectively.
_UNIQUE_VIOLATION_MARKERS = (
    "duplicate key value violates unique constraint",
    "UNIQUE constraint failed",
)


def _raise_for_integrity_error(db: Session, error: IntegrityError) -> None:
    """Roll back the session and re-raise an IntegrityError, mapping unique violations to HTTP 400.

    Any other integrity error is re-raised unchanged so that it is never silently ignored.
    :param db: The database session to roll back.
    :param error: The caught IntegrityError.
    :raises HTTPException: 400 if the error is a unique-constraint violation.
    :raises IntegrityError: The original error otherwise."""

    db.rollback()
    if any(marker in str(error.orig) for marker in _UNIQUE_VIOLATION_MARKERS):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Update would violate a unique constraint",
        ) from error
    raise error


def generate_data_table_crud_router(
    *,
    table_model,
    create_schema=None,
    update_schema=None,
    out_schema=None,
    endpoint: str,
    not_found_msg: str = "Entry not found",
    many_to_many_fields: dict | None = None,
    router: APIRouter | None = None,
    admin_only: bool = False,
    allowed_actions: list[str] | None = None,
    transform: None | Callable = None,
    check_settings: None | Callable = None,
) -> APIRouter:
    """Generate a FastAPI router with standard CRUD endpoints for a given table.

    Registered endpoints depend on ``allowed_actions``: ``GET /`` ("get" or "get_all"),
    ``GET /{entry_id}`` ("get" or "get_one"), ``POST /`` ("post"), ``PUT /{entry_id}`` ("put")
    and ``DELETE /{entry_id}`` ("delete").

    :param table_model: SQLAlchemy model class representing the database table.
    :param create_schema: Pydantic schema used for creating new entries.
    :param update_schema: Pydantic schema used for updating existing entries.
    :param out_schema: Pydantic schema used for serialising output.
    :param endpoint: Endpoint name (used as route prefix and tag).
    :param not_found_msg: Default message when an entry is not found.
    :param many_to_many_fields: Dict defining M2M relationships.
        Format: {'field_name': {
            'table': association_table,
            'local_key': 'local_foreign_key',
            'remote_key': 'remote_foreign_key',
            'related_model': RelatedModelClass}}
    :param router: Optional router to which the endpoints will be added.
    :param admin_only: If True, restrict access to admin users only.
    :param allowed_actions: List of allowed actions (get, post, put, delete). If None, all are allowed.
    :param transform: Optional function to transform the data before saving.
        Called as ``transform(data, db)`` on create and ``transform(data, db, entry_data)`` on update;
        its returned dict is merged into the data.
    :param check_settings: Optional function to check that the operation is allowed by the settings
        before create or update. Called as ``check_settings(db, item)``.
    :return: Configured APIRouter instance with CRUD endpoints."""

    if router is None:
        router = APIRouter(prefix=f"/{endpoint}", tags=[endpoint])

    if allowed_actions is None:
        allowed_actions = ["get", "post", "put", "delete"]

    # Exceptions are built per raise instead of shared: raising the same instance again prepends
    # the new frames to its existing __traceback__, so a shared object would grow with every request.
    def not_found() -> HTTPException:
        """Build a fresh 404 exception carrying ``not_found_msg``."""
        return HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found_msg)

    def not_allowed() -> HTTPException:
        """Build a fresh 403 exception with the status and detail of NOT_ALLOWED_EXCEPTION."""
        return HTTPException(
            status_code=NOT_ALLOWED_EXCEPTION.status_code,
            detail=NOT_ALLOWED_EXCEPTION.detail,
        )

    def check_admin(current_user: models.User) -> None:
        """Raise a 403 if the table is for admins only and the user is not an admin.
        :param current_user: The current authenticated user."""

        if admin_only and not current_user.is_admin:
            raise not_allowed()

    def check_ownership(
        entry: Any,
        current_user: models.User,
    ) -> None:
        """Raise a 403 if the entry has an owner_id field that is not the current user's ID.
        :param entry: The database entry.
        :param current_user: The current authenticated user."""

        if hasattr(entry, "owner_id") and entry.owner_id != current_user.id:
            raise not_allowed()

    def present(entry: Any, current_user: models.User) -> Any:
        """Return ``entry`` as-is for admins; otherwise hide related objects owned by other users.
        :param entry: The database entry to return.
        :param current_user: The current authenticated user."""

        if current_user.is_admin:
            return entry
        return filter_out_non_owned(entry, current_user.id)

    def split_many_to_many(item_dict: dict) -> tuple[dict, dict]:
        """Separate many-to-many fields from the regular column values.
        :param item_dict: Data dumped from the request schema.
        :return: ``(main_data, m2m_data)``; ``main_data`` is a copy without the many-to-many fields."""

        main_data = dict(item_dict)
        m2m_data = {}
        for field_name in many_to_many_fields or {}:
            if field_name in main_data:
                m2m_data[field_name] = main_data.pop(field_name)
        return main_data, m2m_data

    def upsert_many_to_many(
        db: Session,
        entry_id: int,
        item_data: dict,
        owner_id: int,
        clear_existing: bool = False,
    ):
        """Handle creation or update of many-to-many relationships with owner check.
        A related object is linked only if it exists and either has no owner_id or is owned by ``owner_id``;
        other IDs are skipped silently.
        :param db: Database session
        :param entry_id: ID of the entry to which the relationships are being added
        :param item_data: Data containing the relationships to be added
        :param owner_id: ID of the current user (owner)
        :param clear_existing: If True, delete existing relationships before adding new ones"""

        # NOTE(review): tables without an owner_id column never get their many-to-many data
        # stored here; confirm this is intended.
        if not many_to_many_fields or not hasattr(table_model, "owner_id"):
            return

        for field_name, m2m_config in many_to_many_fields.items():
            value_ids = item_data.get(field_name)
            if value_ids is None:
                continue

            association_table = m2m_config["table"]
            local_key = m2m_config["local_key"]
            remote_key = m2m_config["remote_key"]
            related_model = m2m_config["related_model"]

            if clear_existing:
                db.execute(association_table.delete().where(getattr(association_table.c, local_key) == entry_id))

            for value_id in value_ids:
                related_obj = db.query(related_model).filter(related_model.id == value_id).first()
                if not related_obj:
                    continue
                if not hasattr(related_obj, "owner_id") or related_obj.owner_id == owner_id:
                    db.execute(association_table.insert().values(**{local_key: entry_id, remote_key: value_id}))

    # ------------------------------------------------------- GET ------------------------------------------------------

    if "get" in allowed_actions or "get_all" in allowed_actions:

        @router.get("/", response_model=list[out_schema])  # noqa
        def get_all(
            request: Request,
            db: Session = Depends(database.get_db),
            current_user: models.User = Depends(oauth2.get_current_user),
            limit: int | None = None,
        ):
            """Retrieve all entries for the current user.
            Any query parameter other than ``limit`` that names a column is applied as a filter;
            repeated parameters become an IN filter.
            :param request: FastAPI request object to access query parameters
            :param db: Database session.
            :param current_user: Authenticated user.
            :param limit: Maximum number of entries to return.
            :return: List of entries.
            :raises: HTTPException with a 403 status code if not authorised to perform the requested action."""

            # Check if admin rights are needed (rejects non-admins on admin-only tables)
            check_admin(current_user)

            # Admin-only tables show every row; other tables are restricted to the user's own
            # rows when they have an owner_id column.
            query = db.query(table_model)
            if not admin_only and hasattr(table_model, "owner_id"):
                query = query.filter(table_model.owner_id == current_user.id)

            # Collect query parameters, keeping repeated keys as lists
            filter_params = {}
            for key in request.query_params.keys():
                if key == "limit":
                    continue
                values = request.query_params.getlist(key)
                filter_params[key] = values[0] if len(values) == 1 else values

            query = filter_query(query, table_model, filter_params)

            results = query.limit(limit).all()
            return [present(result, current_user) for result in results]

    if "get" in allowed_actions or "get_one" in allowed_actions:

        @router.get("/{entry_id}", response_model=out_schema)
        def get_one(
            entry_id: int,
            db: Session = Depends(database.get_db),
            current_user: models.User = Depends(oauth2.get_current_user),
        ):
            """Get an entry by ID.
            :param entry_id: The entry ID.
            :param db: The database session.
            :param current_user: The current user.
            :returns: The entry if found.
            :raises: HTTPException with a 404 status code if the entry is not found.
            :raises: HTTPException with a 403 status code if not authorised to perform the requested action."""

            check_admin(current_user)

            entry = db.query(table_model).filter(table_model.id == entry_id).first()
            if not entry:
                raise not_found()

            # Ensure that the user is authorised to view this entry
            check_ownership(entry, current_user)

            return present(entry, current_user)

    # ------------------------------------------------------ POST ------------------------------------------------------

    if "post" in allowed_actions:

        @router.post("/", status_code=status.HTTP_201_CREATED, response_model=out_schema)
        def create(
            item: create_schema,  # noqa
            db: Session = Depends(database.get_db),
            current_user: models.User = Depends(oauth2.get_current_user),
        ):
            """Create a new entry.
            The entry and its many-to-many links are committed in a single transaction.
            :param item: Data for the new entry.
            :param db: Database session.
            :param current_user: Authenticated user.
            :return: The created entry.
            :raises: HTTPException with a 403 status code if not authorised to perform the requested action.
            :raises: HTTPException with a 400 status code if the table is full or a unique constraint is violated."""

            check_admin(current_user)

            if check_settings:
                check_settings(db, item)

            # Enforce a maximum of 10,000 entries (counted over the whole table)
            if db.query(table_model).count() >= 10_000:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Maximum number of entries reached")

            item_dict = item.model_dump()
            if transform:
                item_dict.update(transform(item_dict, db))

            main_data, m2m_data = split_many_to_many(item_dict)

            # Add the owner id if the table has an owner_id field
            if hasattr(table_model, "owner_id"):
                main_data["owner_id"] = current_user.id

            new_entry = table_model(**main_data)
            db.add(new_entry)
            # flush() assigns new_entry.id without committing, so a failure while linking
            # many-to-many rows rolls back the whole creation instead of leaving a partial entry.
            try:
                db.flush()
                if m2m_data:
                    upsert_many_to_many(db, new_entry.id, m2m_data, current_user.id)
                db.commit()
            except IntegrityError as e:
                _raise_for_integrity_error(db, e)
            db.refresh(new_entry)

            return present(new_entry, current_user)

    # ------------------------------------------------------- PUT ------------------------------------------------------

    if "put" in allowed_actions:

        @router.put("/{entry_id}", status_code=status.HTTP_200_OK, response_model=out_schema)
        def update(
            entry_id: int,
            item: update_schema,  # noqa
            db: Session = Depends(database.get_db),
            current_user: models.User = Depends(oauth2.get_current_user),
        ):
            """Update an entry by ID.
            Only fields explicitly set in the request are applied; many-to-many fields that are
            present replace the existing links.
            :param entry_id: The entry ID.
            :param item: The updated data.
            :param db: The database session.
            :param current_user: The current user.
            :returns: The updated entry.
            :raises: HTTPException with a 404 status code if an entry is not found.
            :raises: HTTPException with a 403 status code if not authorised to perform the requested action.
            :raises: HTTPException with a 400 status code if no field is provided for the update
                or a unique constraint is violated."""

            check_admin(current_user)

            if check_settings:
                check_settings(db, item)

            entry = db.query(table_model).filter(table_model.id == entry_id).first()
            if not entry:
                raise not_found()

            # Ensure that the user is authorised to modify this entry
            check_ownership(entry, current_user)

            item_dict = item.model_dump(exclude_unset=True)
            if not item_dict:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No fields provided for update")

            main_data, m2m_data = split_many_to_many(item_dict)

            # Let the transform see the current column values alongside the updates
            if transform:
                entry_data = {c.name: getattr(entry, c.name) for c in entry.__table__.columns}
                main_data.update(transform(main_data, db, entry_data))

            # Everything that can flush (lazy loads, the M2M queries' autoflush, commit) is inside the
            # try, so an IntegrityError is handled the same way wherever the session raises it.
            try:
                for field, value in main_data.items():
                    if isinstance(value, dict):
                        # NOTE(review): a dict is applied key-by-key onto the object currently stored in
                        # `field` (presumably a related object); a plain dict/JSON column would fail here.
                        target = getattr(entry, field)
                        for key, sub_value in value.items():
                            setattr(target, key, sub_value)
                    else:
                        setattr(entry, field, value)

                if m2m_data:
                    upsert_many_to_many(db, entry_id, m2m_data, current_user.id, True)

                db.commit()
            except IntegrityError as e:
                _raise_for_integrity_error(db, e)
            db.refresh(entry)

            return present(entry, current_user)

    # ----------------------------------------------------- DELETE -----------------------------------------------------

    if "delete" in allowed_actions:

        @router.delete("/{entry_id}", status_code=status.HTTP_204_NO_CONTENT)
        def delete(
            entry_id: int,
            db: Session = Depends(database.get_db),
            current_user: models.User = Depends(oauth2.get_current_user),
        ):
            """Delete an entry by ID, together with its many-to-many association rows.
            :param entry_id: The entry ID.
            :param db: The database session.
            :param current_user: The current user.
            :returns: Nothing (204 No Content).
            :raises: HTTPException with a 404 status code if an entry is not found.
            :raises: HTTPException with a 403 status code if not authorised to perform the requested action."""

            check_admin(current_user)

            query = db.query(table_model).filter(table_model.id == entry_id)
            entry = query.first()
            if not entry:
                raise not_found()

            # Ensure that the user is authorised to delete this entry
            check_ownership(entry, current_user)

            try:
                # Remove association rows first so none of them reference the deleted entry
                for m2m_config in (many_to_many_fields or {}).values():
                    association_table = m2m_config["table"]
                    local_column = getattr(association_table.c, m2m_config["local_key"])
                    db.execute(association_table.delete().where(local_column == entry_id))

                query.delete(synchronize_session=False)
                db.commit()
            except IntegrityError as e:
                _raise_for_integrity_error(db, e)

    return router