diff --git a/stix2/datastore/relational_db/database_backends/database_backend_base.py b/stix2/datastore/relational_db/database_backends/database_backend_base.py index 4cf46f1b..e5082451 100644 --- a/stix2/datastore/relational_db/database_backends/database_backend_base.py +++ b/stix2/datastore/relational_db/database_backends/database_backend_base.py @@ -18,9 +18,6 @@ def __init__(self, database_connection_url, force_recreate=False, **kwargs: Any) self.database_connection = create_engine(database_connection_url) - def _fk_pragma_on_connect(self): - self.database_connnection.execute('pragma foreign_keys=ON') - def _create_database(self): if self.database_exists: drop_database(self.database_connection_url) diff --git a/stix2/datastore/relational_db/database_backends/sqlite_backend.py b/stix2/datastore/relational_db/database_backends/sqlite_backend.py index ae5473bf..f8094a15 100644 --- a/stix2/datastore/relational_db/database_backends/sqlite_backend.py +++ b/stix2/datastore/relational_db/database_backends/sqlite_backend.py @@ -2,6 +2,7 @@ from typing import Any from sqlalchemy import TIMESTAMP, LargeBinary, Text +from sqlalchemy.engine import Engine from sqlalchemy import event from stix2.base import ( @@ -18,7 +19,11 @@ class SQLiteBackend(DatabaseBackend): def __init__(self, database_connection_url=default_database_connection_url, force_recreate=False, **kwargs: Any): super().__init__(database_connection_url, force_recreate=force_recreate, **kwargs) - event.listen(self.database_connection, 'connect', self._fk_pragma_on_connect) + set_sqlite_pragma(self) + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(self): + self.database_connection.execute("PRAGMA foreign_keys=ON") # ========================================================================= # sql type methods (overrides)