diff --git a/docs/core/web.rst b/docs/core/web.rst index 4e2e7e58dd..e4fdff9ebf 100644 --- a/docs/core/web.rst +++ b/docs/core/web.rst @@ -135,7 +135,138 @@ For example:: rest_endpoints=RestEndpointConfig(), ) - module = Module(name="tests.users") + module = Module( + name="tests.users", + resources=[user_resource_config], + ) + + +Resource REST Endpoint with Parent +---------------------------------- + +REST endpoints can also include a parent/child relationship with the resource. This is achieved using the +:class:`RestParentLink ` +attribute on the RestEndpointConfig. + +Example config:: + + from typing import Annotated + from superdesk.core.module import Module + from superdesk.core.resources import ( + ResourceConfig, + ResourceModel, + RestEndpointConfig, + RestParentLink, + ) + from superdesk.core.resources.validators import ( + validate_data_relation_async, + ) + + # 1. Define parent resource and config + class Company(ResourceModel): + name: str + + company_resource_config = ResourceConfig( + name="companies", + data_class=Company, + rest_endpoints=RestEndpointConfig() + ) + + # 2. Define child resource and config + class User(ResourceModel): + first_name: str + last_name: str + + # 2a. Include a field that references the parent + company: Annotated[ + str, + validate_data_relation_async( + company_resource_config.name, + ), + ] + + user_resource_config = ResourceConfig( + name="users", + data_class=User, + rest_endpoints=RestEndpointConfig( + + # 2b. Include a link to Company as a parent resource + parent_links=[ + RestParentLink( + resource_name=company_resource_config.name, + model_id_field="company", + ), + ], + ), + ) + + # 3. Register the resources with a module + module = Module( + name="tests.users", + resources=[ + company_resource_config, + user_resource_config, + ], + ) + + +The above example exposes the following URLs: + +* /api/companies +* /api/companies/```` +* /api/companies/````/users +* /api/companies/````/users/```` + +As you can see the ``users`` endpoints are prefixed with ``/api/company//``. + +This provides the following functionality: + +* Validation that a Company must exist for the user +* Populates the ``company`` field of a User with the ID from the URL +* When searching for users, will only provide users for the specific company provided in the URL of the request + +For example:: + + async def test_users(): + # Create the parent Company + response = await client.post( + "/api/company", + json={"name": "Sourcefabric"} + ) + + # Retrieve the Company ID from the response + company_id = (await response.get_json())[0] + + # Attemps to create a user with non-existing company + # responds with a 404 - NotFound error + response = await client.post( + f"/api/company/blah_blah/users", + json={"first_name": "Monkey", "last_name": "Mania"} + ) + assert response.status_code == 404 + + # Create the new User + # Notice the ``company_id`` is used in the URL + response = await client.post( + f"/api/company/{company_id}/users", + json={"first_name": "Monkey", "last_name": "Mania"} + ) + user_id = (await response.get_json())[0] + + # Retrieve the new user + response = await client.get( + f"/api/company/{company_id}/users/{user_id}" + ) + user_dict = await response.get_json() + assert user_dict["company"] == company_id + + # Retrieve all company users + response = await client.get( + f"/api/company/{company_id}/users" + ) + users_dict = (await response.get_json())["_items"] + assert len(users_dict) == 1 + assert users_dict[0]["_id"] == user_id Validation @@ -240,6 +371,11 @@ API References :members: :undoc-members: +.. autoclass:: superdesk.core.resources.resource_rest_endpoints.RestParentLink + :member-order: bysource + :members: + :undoc-members: + .. autoclass:: superdesk.core.resources.resource_rest_endpoints.ResourceRestEndpoints :member-order: bysource :members: diff --git a/pyproject.toml b/pyproject.toml index 4d7f66b06d..7944e13c2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,3 +21,4 @@ exclude = ''' testpaths = ["tests", "superdesk", "apps", "content_api"] python_files = "*_test.py *_tests.py test_*.py tests_*.py tests.py test.py" asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/superdesk/core/elastic/resources.py b/superdesk/core/elastic/resources.py index f8e5229ab1..d944e1f7fe 100644 --- a/superdesk/core/elastic/resources.py +++ b/superdesk/core/elastic/resources.py @@ -88,12 +88,13 @@ def register_resource_config( client_config = ElasticClientConfig.create_from_dict( self.app.wsgi.config, prefix=resource_config.prefix or "ELASTICSEARCH", freeze=False ) - client_config.index += f"_{resource_name}" + source_name = self.app.resources.get_config(resource_name).datasource_name or resource_name + client_config.index += f"_{source_name}" client_config.set_frozen(True) - self._resource_clients[resource_name] = ElasticResourceClient(resource_name, client_config, resource_config) + self._resource_clients[resource_name] = ElasticResourceClient(source_name, client_config, resource_config) self._resource_async_clients[resource_name] = ElasticResourceAsyncClient( - resource_name, client_config, resource_config + source_name, client_config, resource_config ) def get_client(self, resource_name) -> ElasticResourceClient: diff --git a/superdesk/core/mongo.py b/superdesk/core/mongo.py index e40eeabff0..dcaacd0687 100644 --- a/superdesk/core/mongo.py +++ b/superdesk/core/mongo.py @@ -218,7 +218,8 @@ def get_all_resource_configs(self) -> Dict[str, MongoResourceConfig]: return deepcopy(self._resource_configs) def get_collection_name(self, resource_name: str, versioning: bool = False) -> str: - return resource_name if not versioning else f"{resource_name}_versions" + source_name = self.app.resources.get_config(resource_name).datasource_name or resource_name + return source_name if not versioning else f"{source_name}_versions" def reset_all_async_connections(self): for client, _db in self._mongo_clients_async.values(): @@ -265,7 +266,7 @@ def get_client(self, resource_name: str, versioning: bool = False) -> Tuple[Mong if not self._mongo_clients.get(mongo_config.prefix): client_config, dbname = get_mongo_client_config(self.app.wsgi.config, mongo_config.prefix) client: MongoClient = MongoClient(**client_config) - db = client.get_database(self.get_collection_name(dbname, versioning)) + db = client.get_database(dbname if not versioning else f"{dbname}_versions") self._mongo_clients[mongo_config.prefix] = (client, db) return self._mongo_clients[mongo_config.prefix] @@ -388,7 +389,7 @@ def get_client_async( if not self._mongo_clients_async.get(mongo_config.prefix): client_config, dbname = get_mongo_client_config(self.app.wsgi.config, mongo_config.prefix) client = AsyncIOMotorClient(**client_config) - db = client.get_database(self.get_collection_name(dbname, versioning)) + db = client.get_database(dbname if not versioning else f"{dbname}_versions") self._mongo_clients_async[mongo_config.prefix] = (client, db) return self._mongo_clients_async[mongo_config.prefix] diff --git a/superdesk/core/resources/__init__.py b/superdesk/core/resources/__init__.py index 6073660210..94c74b3e97 100644 --- a/superdesk/core/resources/__init__.py +++ b/superdesk/core/resources/__init__.py @@ -9,7 +9,7 @@ # at https://www.sourcefabric.org/superdesk/license from .model import Resources, ResourceModel, ResourceModelWithObjectId, ModelWithVersions, ResourceConfig, dataclass -from .resource_rest_endpoints import RestEndpointConfig +from .resource_rest_endpoints import RestEndpointConfig, RestParentLink, get_id_url_type from .service import AsyncResourceService, AsyncCacheableService from ..mongo import MongoResourceConfig, MongoIndexOptions from ..elastic.resources import ElasticResourceConfig @@ -23,6 +23,8 @@ "dataclass", "fields", "RestEndpointConfig", + "RestParentLink", + "get_id_url_type", "AsyncResourceService", "AsyncCacheableService", "MongoResourceConfig", diff --git a/superdesk/core/resources/model.py b/superdesk/core/resources/model.py index 571de1a1da..73ae419329 100644 --- a/superdesk/core/resources/model.py +++ b/superdesk/core/resources/model.py @@ -264,6 +264,9 @@ class ResourceConfig: #: Optional sorting for this resource default_sort: SortListParam | None = None + #: Optionally override the name used for the MongoDB/Elastic sources + datasource_name: str | None = None + class Resources: """A high level resource class used to manage all resources in the system""" @@ -295,6 +298,8 @@ def register(self, config: ResourceConfig): self._resource_configs[config.name] = config config.data_class.model_resource_name = config.name + if not config.datasource_name: + config.datasource_name = config.name mongo_config = config.mongo or MongoResourceConfig() if config.versioning: diff --git a/superdesk/core/resources/resource_rest_endpoints.py b/superdesk/core/resources/resource_rest_endpoints.py index 8787035f68..65747e8035 100644 --- a/superdesk/core/resources/resource_rest_endpoints.py +++ b/superdesk/core/resources/resource_rest_endpoints.py @@ -16,11 +16,12 @@ from eve.utils import querydef from typing_extensions import override from werkzeug.datastructures import MultiDict +from bson import ObjectId +from superdesk.core import json from superdesk.core.app import get_current_async_app from superdesk.core.types import SearchRequest, SearchArgs, VersionParam from superdesk.errors import SuperdeskApiError -from superdesk.core.types import SearchRequest, SearchArgs from ..web.types import HTTP_METHOD, Request, Response, RestGetResponse from ..web.rest_endpoints import RestEndpoints, ItemRequestViewArgs @@ -30,6 +31,31 @@ from .utils import resource_uses_objectid_for_id +@dataclass +class RestParentLink: + #: Name of the resource this parent link belongs to + resource_name: str + + #: Field used to store the resource ID in the child resource, defaults to ``resource_name`` + model_id_field: str | None = None + + #: Name of the URL argument in the route, defaults to ``model_id_field`` + url_arg_name: str | None = None + + #: ID Field of the parent used when searching for parent item resource, defaults to ``model_id_field`` + parent_id_field: str = "_id" + + def get_model_id_field(self) -> str: + """Get the ID Field for the local model used to store the reference to the parent model""" + + return self.model_id_field or self.resource_name + + def get_url_arg_name(self) -> str: + """Get the name of hte URL argument used in the route""" + + return self.url_arg_name or self.model_id_field or self.resource_name + + @dataclass class RestEndpointConfig: #: Optional list of resource level methods, defaults to ["GET", "POST"] @@ -44,6 +70,13 @@ class RestEndpointConfig: #: Optionally set a custom URL ID param syntax for item routes id_param_type: Optional[str] = None + #: Optionally set a custom URL for routes, defaults to ``resource_name`` + url: str | None = None + + #: Optionally assign parent resource(s) for this resource (parent/child relationship) + #: This will prepend this resources URL with the URL of the parent resource item + parent_links: list[RestParentLink] | None = None + def get_id_url_type(data_class: type[ResourceModel]) -> str: """Get the URL param type for the ID field for route registration""" @@ -74,16 +107,102 @@ def __init__( resource_config: ResourceConfig, endpoint_config: RestEndpointConfig, ): + self.resource_config = resource_config + self.endpoint_config = endpoint_config super().__init__( - url=resource_config.name, + url=endpoint_config.url or resource_config.name, name=resource_config.name, import_name=resource_config.__module__, resource_methods=endpoint_config.resource_methods, item_methods=endpoint_config.item_methods, id_param_type=endpoint_config.id_param_type or get_id_url_type(resource_config.data_class), ) - self.resource_config = resource_config - self.endpoint_config = endpoint_config + + def get_resource_url(self) -> str: + """Returns the URL for this resource + + If the resource has ``parent_links`` configured, these will be used to construct the URL + with the parent resources URL and item ID + """ + + if self.endpoint_config.parent_links is None: + return self.url + + app = get_current_async_app() + url = "" + for parent_link in self.endpoint_config.parent_links: + parent_config = app.resources.get_config(parent_link.resource_name) + id_param_type = 'regex("[\w,.:_-]+")' + parent_url = parent_link.resource_name + + if parent_config.rest_endpoints is not None: + if parent_config.rest_endpoints.url: + parent_url = parent_config.rest_endpoints.url + + if parent_config.rest_endpoints.id_param_type: + id_param_type = parent_config.rest_endpoints.id_param_type + else: + id_param_type = get_id_url_type(parent_config.data_class) + + arg_name = parent_link.get_url_arg_name() + url_prefix = f"{parent_url}/<{id_param_type}:{arg_name}>" + url += url_prefix + "/" + + return url + self.url + + def get_item_url(self, arg_name: str = "item_id") -> str: + """Returns the URL for an item of this resource + + :param arg_name: The name of the URL argument to use for the resource item URL + :return: The URL for an item of this resource + """ + + return f"{self.get_resource_url()}/<{self.id_param_type}:{arg_name}>" + + async def get_parent_items(self, request: Request) -> dict[str, dict]: + """Returns a dictionary of resource name to item for configured parent links + + :return: A dictionary, with the key being the resource name and value being the parent item + :raises SuperdeskApiError.badRequestError: If a parent item is not found + """ + + if self.endpoint_config.parent_links is None: + return {} + + items: dict[str, dict] = {} + for parent_link in self.endpoint_config.parent_links: + service = get_current_async_app().resources.get_resource_service(parent_link.resource_name) + item_id = request.get_view_args(parent_link.get_url_arg_name()) + if not item_id: + raise SuperdeskApiError.badRequestError("Parent resource ID not provided in URL") + item = await service.find_one_raw(use_mongo=True, version=None, **{parent_link.parent_id_field: item_id}) + if not item: + raise SuperdeskApiError.notFoundError( + f"Parent resource {parent_link.resource_name} with ID '{item_id}' not found" + ) + items[parent_link.resource_name] = item + + return items + + def construct_parent_item_lookup(self, request: Request) -> dict: + """Prefills a MongoDB query with the parent attributes from the request + + This is used to filter items of this resource to make sure they belong to all parent item(s). + + :param request: The request object currently being processed + :return: A MongoDB query + """ + if self.endpoint_config.parent_links is None: + return {} + + lookup = {} + for parent_link in self.endpoint_config.parent_links: + service = get_current_async_app().resources.get_resource_service(parent_link.resource_name) + item_id: str | ObjectId | None = request.get_view_args(parent_link.get_url_arg_name()) + if service.id_uses_objectid(): + item_id = ObjectId(item_id) + lookup[parent_link.get_model_id_field()] = item_id + return lookup @property def service(self): @@ -98,6 +217,9 @@ async def get_item( ) -> Response: """Processes a get single item request""" + await self.get_parent_items(request) + service = self.service + if params.version == "all": items, count = await self.service.get_all_item_versions(args.item_id, params.max_results, params.page) response = RestGetResponse( @@ -118,8 +240,12 @@ async def get_item( ), ) return Response(response, 200, [("X-Total-Count", count)]) - - item = await self.service.find_by_id_raw(args.item_id, params.version) + elif self.endpoint_config.parent_links: + lookup = self.construct_parent_item_lookup(request) + lookup["_id"] = args.item_id if not service.id_uses_objectid() else ObjectId(args.item_id) + item = await service.find_one_raw(use_mongo=True, version=params.version, **lookup) + else: + item = await service.find_by_id_raw(args.item_id, params.version) if not item: raise SuperdeskApiError.notFoundError( @@ -135,6 +261,7 @@ async def get_item( async def create_item(self, request: Request) -> Response: """Processes a create item request""" + parent_items = await self.get_parent_items(request) service = self.service payload = await request.get_json() @@ -150,6 +277,12 @@ async def create_item(self, request: Request) -> Response: try: if "_id" not in value: value["_id"] = service.generate_id() + + for parent_link in self.endpoint_config.parent_links or []: + parent_item = parent_items.get(parent_link.resource_name) + if parent_item is not None: + value[parent_link.get_model_id_field()] = parent_item[parent_link.parent_id_field] + model_instance = self.resource_config.data_class.model_validate(value) model_instances.append(model_instance) except ValidationError as validation_error: @@ -166,6 +299,7 @@ async def update_item( ) -> Response: """Processes an update item request""" + await self.get_parent_items(request) payload = await request.get_json() if_match = request.get_header("If-Match") @@ -187,6 +321,7 @@ async def update_item( async def delete_item(self, args: ItemRequestViewArgs, params: None, request: Request) -> Response: """Processes a delete item request""" + await self.get_parent_items(request) service = self.service original = await service.find_by_id(args.item_id) @@ -212,6 +347,18 @@ async def search_items( ) -> Response: """Processes a search request""" + await self.get_parent_items(request) + + if len(self.endpoint_config.parent_links or []): + if not isinstance(params.where, dict): + if params.where is None: + params.where = {} + elif isinstance(params.where, str): + params.where = cast(dict, json.loads(params.where)) + + lookup = self.construct_parent_item_lookup(request) + params.where.update(lookup) + params.args = cast(SearchArgs, params.model_extra) cursor = await self.service.find(params) count = await cursor.count() diff --git a/superdesk/core/resources/service.py b/superdesk/core/resources/service.py index 431f56b544..bdf70b5f42 100644 --- a/superdesk/core/resources/service.py +++ b/superdesk/core/resources/service.py @@ -119,17 +119,23 @@ def get_model_instance_from_dict(self, data: Dict[str, Any]) -> ResourceModelTyp data.pop("_type", None) return cast(ResourceModelType, self.config.data_class.model_validate(data)) - async def find_one(self, version: int | None = None, **lookup) -> Optional[ResourceModelType]: + async def find_one_raw(self, use_mongo: bool = False, version: int | None = None, **lookup) -> dict | None: """Find a resource by ID + :param use_mongo: If ``True`` will force use mongo, else will attempt elastic first :param version: Optional version to get :param lookup: Dictionary of key/value pairs used to find the document :return: ``None`` if resource not found, otherwise an instance of ``ResourceModel`` for this resource """ + item = None try: - item = await self.elastic.find_one(**lookup) + if not use_mongo: + item = await self.elastic.find_one(**lookup) except KeyError: + pass + + if use_mongo or item is None: item = await self.mongo.find_one(lookup) if item is None: @@ -137,7 +143,21 @@ async def find_one(self, version: int | None = None, **lookup) -> Optional[Resou elif version is not None: item = await self.get_item_version(item, version) - return self.get_model_instance_from_dict(item) + return item + + async def find_one( + self, use_mongo: bool = False, version: int | None = None, **lookup + ) -> Optional[ResourceModelType]: + """Find a resource by ID + + :param use_mongo: If ``True`` will force use mongo, else will attempt elastic first + :param version: Optional version to get + :param lookup: Dictionary of key/value pairs used to find the document + :return: ``None`` if resource not found, otherwise an instance of ``ResourceModel`` for this resource + """ + + item = await self.find_one_raw(use_mongo=use_mongo, version=version, **lookup) + return None if not item else self.get_model_instance_from_dict(item) async def find_by_id(self, item_id: Union[str, ObjectId]) -> Optional[ResourceModelType]: """Find a resource by ID diff --git a/superdesk/core/web/rest_endpoints.py b/superdesk/core/web/rest_endpoints.py index 614cdc57e0..ca7044b8d0 100644 --- a/superdesk/core/web/rest_endpoints.py +++ b/superdesk/core/web/rest_endpoints.py @@ -50,10 +50,11 @@ def __init__( self.item_methods = item_methods or ["GET", "PATCH", "DELETE"] self.id_param_type = id_param_type or "string" + resource_url = self.get_resource_url() if "GET" in self.resource_methods: self.endpoints.append( Endpoint( - url=self.url, + url=resource_url, name="resource_get", func=self.search_items, methods=["GET"], @@ -63,14 +64,14 @@ def __init__( if "POST" in self.resource_methods: self.endpoints.append( Endpoint( - url=self.url, + url=resource_url, name="resource_post", func=self.create_item, methods=["POST"], ) ) - item_url = f"{self.url}/<{self.id_param_type}:item_id>" + item_url = self.get_item_url() if "GET" in self.item_methods: self.endpoints.append( Endpoint( @@ -101,6 +102,12 @@ def __init__( ) ) + def get_resource_url(self): + return self.url + + def get_item_url(self, arg_name: str = "item_id"): + return f"{self.get_resource_url()}/<{self.id_param_type}:{arg_name}>" + async def get_item( self, args: ItemRequestViewArgs, diff --git a/superdesk/core/web/types.py b/superdesk/core/web/types.py index ca755f15b2..a15712f379 100644 --- a/superdesk/core/web/types.py +++ b/superdesk/core/web/types.py @@ -205,6 +205,12 @@ async def get_data(self) -> Union[bytes, str]: async def abort(self, code: int, *args: Any, **kwargs: Any) -> NoReturn: ... + def get_view_args(self, key: str) -> str | None: + ... + + def get_url_arg(self, key: str) -> str | None: + ... + class EndpointGroup: """Base class used for registering a group of endpoints""" diff --git a/superdesk/factory/app.py b/superdesk/factory/app.py index 0b0a638079..dc92c5239c 100644 --- a/superdesk/factory/app.py +++ b/superdesk/factory/app.py @@ -80,6 +80,12 @@ async def get_data(self) -> Union[bytes, str]: async def abort(self, code: int, *args: Any, **kwargs: Any) -> NoReturn: abort(code, *args, **kwargs) + def get_view_args(self, key: str) -> str | None: + return None if not self.request.view_args else self.request.view_args.get(key, None) + + def get_url_arg(self, key: str) -> str | None: + return self.request.args.get(key, None) + def set_error_handlers(app): """Set error handlers for the given application object. diff --git a/superdesk/tests/__init__.py b/superdesk/tests/__init__.py index da9b2b8435..14da8aa7cf 100644 --- a/superdesk/tests/__init__.py +++ b/superdesk/tests/__init__.py @@ -67,7 +67,7 @@ def get_mongo_uri(key, dbname): return "/".join([env_host, dbname]) -def update_config(conf): +def update_config(conf, auto_add_apps: bool = True): conf["ELASTICSEARCH_INDEX"] = "sptest" conf["MONGO_DBNAME"] = "sptests" conf["MONGO_URI"] = get_mongo_uri("MONGO_URI", "sptests") @@ -100,7 +100,8 @@ def update_config(conf): conf["MACROS_MODULE"] = "superdesk.macros" conf["DEFAULT_TIMEZONE"] = "Europe/Prague" conf["LEGAL_ARCHIVE"] = True - conf["INSTALLED_APPS"].extend(["planning", "superdesk.macros.imperial", "apps.rundowns"]) + if auto_add_apps: + conf["INSTALLED_APPS"].extend(["planning", "superdesk.macros.imperial", "apps.rundowns"]) # limit mongodb connections conf["MONGO_CONNECT"] = False @@ -176,7 +177,7 @@ async def drop_mongo(app): dbconn.drop_database(dbname) -def setup_config(config): +def setup_config(config, auto_add_apps: bool = True): app_abspath = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) app_config = Config(app_abspath) app_config.from_object("superdesk.default_settings") @@ -190,7 +191,7 @@ def setup_config(config): else: logger.warning("Can't find local settings") - update_config(app_config) + update_config(app_config, auto_add_apps) app_config.setdefault("INSTALLED_APPS", []) @@ -360,8 +361,8 @@ def inner(*a, **kw): use_snapshot.cache = {} # type: ignore -async def setup(context=None, config=None, app_factory=get_app, reset=False): - if not hasattr(setup, "app") or setup.reset or config: +async def setup(context=None, config=None, app_factory=get_app, reset=False, auto_add_apps: bool = True): + if not hasattr(setup, "app") or setup.reset or config: # type: ignore[attr-defined] if hasattr(setup, "app"): # Close all PyMongo Connections (new ones will be created with ``app_factory`` call) for key, val in setup.app.extensions["pymongo"].items(): @@ -370,10 +371,10 @@ async def setup(context=None, config=None, app_factory=get_app, reset=False): if getattr(setup.app, "async_app", None): setup.app.async_app.stop() - cfg = setup_config(config) - setup.app = app_factory(cfg) - setup.reset = reset - app = setup.app + cfg = setup_config(config, auto_add_apps) + setup.app = app_factory(cfg) # type: ignore[attr-defined] + setup.reset = reset # type: ignore[attr-defined] + app = setup.app # type: ignore[attr-defined] if context: context.app = app @@ -542,6 +543,9 @@ def setupApp(self): self.app_config = setup_config(self.app_config) self.app = SuperdeskAsyncApp(MockWSGI(config=self.app_config)) + self.startApp() + + def startApp(self): self.app.start() async def asyncSetUp(self): @@ -574,8 +578,19 @@ def model_instance_to_json(self, model_instance: ResourceModel): return model_instance.model_dump(by_alias=True, exclude_unset=True, mode="json") async def post(self, *args, **kwargs) -> Response: - if "json" in kwargs and isinstance(kwargs["json"], ResourceModel): - kwargs["json"] = self.model_instance_to_json(kwargs["json"]) + if "json" in kwargs: + if isinstance(kwargs["json"], ResourceModel): + kwargs["json"] = self.model_instance_to_json(kwargs["json"]) + elif isinstance(kwargs["json"], list): + kwargs["json"] = [ + self.model_instance_to_json(item) if isinstance(item, ResourceModel) else item + for item in kwargs["json"] + ] + elif isinstance(kwargs["json"], dict): + kwargs["json"] = { + key: self.model_instance_to_json(value) if isinstance(value, ResourceModel) else value + for key, value in kwargs["json"].items() + } return await super().post(*args, **kwargs) @@ -583,13 +598,19 @@ async def post(self, *args, **kwargs) -> Response: class AsyncFlaskTestCase(AsyncTestCase): async_app: SuperdeskAsyncApp app: SuperdeskApp + use_default_apps: bool = False async def asyncSetUp(self): if getattr(self, "async_app", None): self.async_app.stop() await self.async_app.elastic.stop() - await setup(self, config=self.app_config, reset=True) + if self.use_default_apps: + await setup(self, config=self.app_config, reset=True, auto_add_apps=True) + else: + self.app_config.setdefault("CORE_APPS", []) + self.app_config.setdefault("INSTALLED_APPS", []) + await setup(self, config=self.app_config, reset=True, auto_add_apps=False) self.async_app = self.app.async_app self.app.test_client_class = TestClient self.test_client = self.app.test_client() @@ -618,4 +639,5 @@ async def get_resource_etag(self, resource: str, item_id: str): return (await (await self.test_client.get(f"/api/{resource}/{item_id}")).get_json())["_etag"] -TestCase = AsyncFlaskTestCase +class TestCase(AsyncFlaskTestCase): + use_default_apps: bool = True diff --git a/tests/core/modules/company.py b/tests/core/modules/company.py new file mode 100644 index 0000000000..211f86fbae --- /dev/null +++ b/tests/core/modules/company.py @@ -0,0 +1,20 @@ +from superdesk.core.module import Module +from superdesk.core.resources import ResourceConfig, ResourceModel, AsyncResourceService, RestEndpointConfig + + +class CompanyResource(ResourceModel): + name: str + + +class CompanyService(AsyncResourceService[CompanyResource]): + resource_name = "companies" + + +companies_resource_config = ResourceConfig( + name="companies", + data_class=CompanyResource, + service=CompanyService, + rest_endpoints=RestEndpointConfig(), +) + +module = Module(name="tests.company", resources=[companies_resource_config]) diff --git a/tests/core/modules/topics.py b/tests/core/modules/topics.py new file mode 100644 index 0000000000..04722323ae --- /dev/null +++ b/tests/core/modules/topics.py @@ -0,0 +1,74 @@ +from typing import Annotated +from superdesk.core.module import Module +from superdesk.core.resources import ( + ResourceConfig, + ResourceModel, + AsyncResourceService, + RestEndpointConfig, + RestParentLink, +) +from superdesk.core.resources.validators import validate_data_relation_async + +from .users import user_model_config +from .company import companies_resource_config + + +class TopicFolder(ResourceModel): + name: str + section: str + + +class UserFolder(TopicFolder): + user: Annotated[str, validate_data_relation_async(user_model_config.name)] + + +class UserFolderService(AsyncResourceService[UserFolder]): + resource_name = "user_topic_folders" + + +user_folder_config = ResourceConfig( + name="user_topic_folders", + datasource_name="topic_folders", + data_class=UserFolder, + service=UserFolderService, + rest_endpoints=RestEndpointConfig( + parent_links=[ + RestParentLink( + resource_name=user_model_config.name, + model_id_field="user", + ) + ], + url="topic_folders", + ), +) + + +class CompanyFolder(TopicFolder): + company: str + + +class CompanyFolderService(AsyncResourceService[CompanyFolder]): + resource_name = "company_topic_folders" + + +company_folder_config = ResourceConfig( + name="company_topic_folders", + datasource_name="topic_folders", + data_class=CompanyFolder, + service=CompanyFolderService, + rest_endpoints=RestEndpointConfig( + parent_links=[ + RestParentLink( + resource_name=companies_resource_config.name, + model_id_field="company", + ) + ], + url="topic_folders", + ), +) + + +module = Module( + name="tests.multi_sources", + resources=[user_folder_config, company_folder_config], +) diff --git a/tests/core/mongo_test.py b/tests/core/mongo_test.py index 10f1c5ef8c..c34251cd3b 100644 --- a/tests/core/mongo_test.py +++ b/tests/core/mongo_test.py @@ -85,3 +85,26 @@ def test_init_indexes(self): # ``collation`` uses an ``bson.son.SON` instance, so use that for testing here self.assertEqual(indexes["combined_name_1"]["collation"].get("locale"), "en") self.assertEqual(indexes["combined_name_1"]["collation"].get("strength"), 1) + + +class MongoClientSourceTestCase(AsyncTestCase): + app_config = { + "MODULES": [ + "tests.core.modules.users", + "tests.core.modules.company", + "tests.core.modules.topics", + ] + } + + def test_mongo_collection_source(self): + user_db = self.app.mongo.get_db("user_topic_folders") + user_collection = user_db.get_collection("topic_folders") + self.assertEqual(self.app.mongo.get_collection("user_topic_folders"), user_collection) + self.assertEqual(user_collection.full_name, "sptests.topic_folders") + + company_db = self.app.mongo.get_db("company_topic_folders") + company_collection = company_db.get_collection("topic_folders") + self.assertEqual(self.app.mongo.get_collection("company_topic_folders"), company_collection) + self.assertEqual(company_collection.full_name, "sptests.topic_folders") + + self.assertEqual(user_collection, company_collection) diff --git a/tests/core/resource_parent_links_test.py b/tests/core/resource_parent_links_test.py new file mode 100644 index 0000000000..7735c603aa --- /dev/null +++ b/tests/core/resource_parent_links_test.py @@ -0,0 +1,138 @@ +from superdesk.core.resources.resource_rest_endpoints import ResourceRestEndpoints +from superdesk.tests import AsyncFlaskTestCase + +from .fixtures.users import john_doe, jane_doe + + +class ResourceParentLinksTestCase(AsyncFlaskTestCase): + app_config = { + "MODULES": [ + "tests.core.modules.users", + "tests.core.modules.company", + "tests.core.modules.topics", + ] + } + + async def test_parent_url_links(self): + user_collection = self.async_app.mongo.get_collection_async("user_topic_folders") + company_collection = self.async_app.mongo.get_collection_async("company_topic_folders") + test_user1 = john_doe() + test_user2 = jane_doe() + + # First add the 2 users we'll use for filtering/testing + response = await self.test_client.post("/api/users_async", json=[test_user1, test_user2]) + self.assertEqual(response.status_code, 201) + + # Make sure the folders resource is empty + self.assertEqual(await user_collection.count_documents({}), 0) + response = await self.test_client.get(f"/api/users_async/{test_user1.id}/topic_folders") + self.assertEqual(len((await response.get_json())["_items"]), 0) + + # Add a folder for each user + response = await self.test_client.post( + f"/api/users_async/{test_user1.id}/topic_folders", json=dict(name="Sports", section="wire") + ) + self.assertEqual(response.status_code, 201) + response = await self.test_client.post( + f"/api/users_async/{test_user2.id}/topic_folders", json=dict(name="Finance", section="agenda") + ) + self.assertEqual(response.status_code, 201) + + # Make sure all folders exist in the mongo collection + self.assertEqual(await user_collection.count_documents({}), 2) + # points to same collection as users folders, so it too should have 2 documents + self.assertEqual(await company_collection.count_documents({}), 2) + + # Test getting folders for User1 + response = await self.test_client.get(f"/api/users_async/{test_user1.id}/topic_folders") + self.assertEqual(response.status_code, 200) + data = await response.get_json() + self.assertEqual(len(data["_items"]), 1) + self.assertEqual(data["_meta"]["total"], 1) + self.assertDictContains(data["_items"][0], dict(user=test_user1.id, name="Sports", section="wire")) + + # Test getting folders for User2 + response = await self.test_client.get(f"/api/users_async/{test_user2.id}/topic_folders") + self.assertEqual(response.status_code, 200) + data = await response.get_json() + self.assertEqual(len(data["_items"]), 1) + self.assertEqual(data["_meta"]["total"], 1) + self.assertDictContains(data["_items"][0], dict(user=test_user2.id, name="Finance", section="agenda")) + + # Test searching folders for User1 + user1_folders_url = f"/api/users_async/{test_user1.id}/topic_folders" + response = await self.test_client.post(user1_folders_url, json=dict(name="Finance", section="agenda")) + self.assertEqual(response.status_code, 201) + + # Make sure there are 2 folders when not filtering + response = await self.test_client.get(user1_folders_url) + self.assertEqual(response.status_code, 200) + data = await response.get_json() + self.assertEqual(len(data["_items"]), 2) + + # Make sure there is only 1 folder when filtering + response = await self.test_client.get(user1_folders_url + '?where={"section":"wire"}') + self.assertEqual(response.status_code, 200) + data = await response.get_json() + self.assertEqual(len(data["_items"]), 1) + self.assertEqual(data["_meta"]["total"], 1) + self.assertDictContains(data["_items"][0], dict(user=test_user1.id, name="Sports", section="wire")) + + async def test_patch_and_delete(self): + # Create the user + test_user1 = john_doe() + response = await self.test_client.post("/api/users_async", json=test_user1) + self.assertEqual(response.status_code, 201) + + # Create the users folder + response = await self.test_client.post( + f"/api/users_async/{test_user1.id}/topic_folders", json=dict(name="Sports", section="wire") + ) + self.assertEqual(response.status_code, 201) + folder_id = (await response.get_json())[0] + + # Get the folder, so we can use it's etag + response = await self.test_client.get(f"/api/users_async/{test_user1.id}/topic_folders/{folder_id}") + folder = await response.get_json() + + # Update the users folder + response = await self.test_client.patch( + f"/api/users_async/{test_user1.id}/topic_folders/{folder_id}", + json=dict(name="Swimming"), + headers={"If-Match": folder["_etag"]}, + ) + self.assertEqual(response.status_code, 200) + + # Delete the users folder + response = await self.test_client.get(f"/api/users_async/{test_user1.id}/topic_folders/{folder_id}") + folder = await response.get_json() + response = await self.test_client.delete( + f"/api/users_async/{test_user1.id}/topic_folders/{folder_id}", headers={"If-Match": folder["_etag"]} + ) + self.assertEqual(response.status_code, 204) + + async def test_parent_link_validation(self): + test_user1 = john_doe() + + # Test request returns 404 when parent item does not exist in the DB + response = await self.test_client.post( + f"/api/users_async/{test_user1.id}/topic_folders", json=dict(name="Sports", section="wire") + ) + self.assertEqual(response.status_code, 404) + + # Now add the parent item, and test request returns 201 + response = await self.test_client.post("/api/users_async", json=test_user1) + self.assertEqual(response.status_code, 201) + response = await self.test_client.post( + f"/api/users_async/{test_user1.id}/topic_folders", json=dict(name="Sports", section="wire") + ) + self.assertEqual(response.status_code, 201) + + def test_generated_resource_url(self): + config = self.async_app.resources.get_config("user_topic_folders") + endpoint = ResourceRestEndpoints(config, config.rest_endpoints) + self.assertEqual(endpoint.get_resource_url(), 'users_async//topic_folders') + + config = self.async_app.resources.get_config("company_topic_folders") + endpoint = ResourceRestEndpoints(config, config.rest_endpoints) + self.assertEqual(endpoint.get_resource_url(), 'companies//topic_folders')