Source code for fluidml.storage.mongo_db_store

import functools
import logging
import pickle
from typing import Any, Dict, Optional

import mongoengine as me

from fluidml.storage.base import ResultsStore, StoreContext

logger = logging.getLogger(__name__)


def connection(func):
    """Decorator to handle connecting and disconnecting to/from db"""

    @functools.wraps(func)
    def wrapper_connect_disconnect_db(self, *args, **kwargs):
        me.connect(db=self._db, host=self._host)
        result = func(self, *args, **kwargs)
        me.disconnect()
        return result

    # support for calling method without decorator logic
    wrapper_connect_disconnect_db.no_connection = func
    return wrapper_connect_disconnect_db


class ResultObject(me.EmbeddedDocument):
    obj = me.FileField()


[docs]class MongoDBStore(ResultsStore): """A mongo db result store implementation. Args: db: The database name to be used. collection_name: The name of the collection. host: The host name of the :program: `mongod` instance. """ def __init__(self, db: str, collection_name: Optional[str] = None, host: Optional[str] = None): super().__init__() self._host = host self._db = db self._collection_name = collection_name
[docs] @connection def load( self, name: str, task_name: str, task_unique_config: Dict, lazy: bool = False, ) -> Optional[Any]: """Query method to load an object based on its name, task_name and task_config if it exists""" task_result_cls = self._get_task_result_class() # try to get query run document based on task name and task unique config try: task_result = task_result_cls.objects(name=task_name).get(unique_config__lte=task_unique_config) except me.DoesNotExist: return None # try to query obj from results DictField try: result_obj = task_result.results[name] obj = pickle.loads(result_obj.obj.read()) except KeyError: logger.warning(f'"{name}" could not be found in store.') return None return obj
[docs] @connection def save( self, obj: Any, name: str, type_: str, task_name: str, task_unique_config: Dict, **kwargs, ): """Method to save/update any artifact""" task_result_cls = self._get_task_result_class() # try to get query run document based on task name and task unique config try: task_result = task_result_cls.objects(name=task_name).get(unique_config=task_unique_config) except me.DoesNotExist: # create new document if query was not successful task_result = task_result_cls(name=task_name, unique_config=task_unique_config) # store object in document and save the document result_obj = ResultObject(obj=pickle.dumps(obj)) task_result.results[name] = result_obj task_result.save()
[docs] @connection def delete(self, name: str, task_name: str, task_unique_config: Dict): """Query method to delete an object based on its name, task_name and task_config if it exists""" task_result_cls = self._get_task_result_class() # try to get query run document based on task name and task unique config try: task_result = task_result_cls.objects(name=task_name).get(unique_config__lte=task_unique_config) except me.DoesNotExist: logger.warning( f'"{name}" could not be deleted. ' f'No document for task "{task_name}" and the provided unique_config exists.' ) return None # try to delete obj from results DictField try: del task_result.results[name] task_result.save() except KeyError: logger.warning(f'"{name}" could not be deleted from store since it was not found.')
[docs] @connection def delete_run(self, task_name: str, task_unique_config: Dict): """Query method to delete a run document based on its task_name and task_config if it exists""" task_result_cls = self._get_task_result_class() # try to get query run document based on task name and task unique config try: task_result = task_result_cls.objects(name=task_name).get(unique_config__lte=task_unique_config) except me.DoesNotExist: logger.warning(f'No document for task "{task_name}" and the provided unique_config exists.') return None # delete document task_result.delete()
[docs] def get_context(self, task_name: str, task_unique_config: Dict) -> StoreContext: pass
def _get_task_result_class(self): # Hack to set the collection name dynamically from user input # Default is the document class name lower-cased, here: "task_result" class TaskResult(me.DynamicDocument): name = me.StringField() unique_config = me.DictField() results = me.DictField(me.EmbeddedDocumentField(ResultObject)) if self._collection_name is not None: meta = {"collection": self._collection_name} return TaskResult
def main(): store = MongoDBStore("test") task_1_config = {"param_a": 23, "param_b": 55} task_2_config = {"param_a": 23, "param_b": 55, "param_c": 88} obj1 = {"Result": 5} obj2 = "hallo" task_1 = "task_1" task_2 = "task_2" store.save( obj=obj1, name="obj1", type_="a", task_name=task_1, task_unique_config=task_1_config, ) store.save( obj=obj2, name="obj2", type_="b", task_name=task_2, task_unique_config=task_2_config, ) # store.delete(name='obj2', task_name=task_name, task_unique_config=task_config) result1 = store.load(name="obj1", task_name=task_1, task_unique_config=task_1_config) result2 = store.load(name="obj2", task_name=task_2, task_unique_config=task_2_config) result3 = store.load(name="obj1", task_name=task_1, task_unique_config=task_2_config) print(result1) print(result2) print(result3) if __name__ == "__main__": main()