|
8 | 8 | logger = structlog.get_logger(__name__) |
9 | 9 |
|
10 | 10 |
|
11 | | -class UnknownConnectionError(Exception): |
12 | | - pass |
13 | | - |
14 | | - |
15 | | -def rough_dict_get(dct, sought, default=None): |
16 | | - """ |
17 | | - Like dct.get(sought), but any key containing sought will do. |
| 11 | +LOCAL_DB_CONN_HANDLE = "@noteable" |
| 12 | +LOCAL_DB_CONN_NAME = "Local Database" |
| 13 | +DUCKDB_LOCATION = "duckdb:///:memory:" |
18 | 14 |
|
19 | | - If there is a `@` in sought, seek each piece separately. |
20 | | - This lets `me@server` match `me:***@myserver/db` |
21 | | - """ |
22 | 15 |
|
23 | | - sought = sought.split("@") |
24 | | - for key, val in dct.items(): |
25 | | - if not any(s.lower() not in key.lower() for s in sought): |
26 | | - return val |
27 | | - return default |
| 16 | +class UnknownConnectionError(Exception): |
| 17 | + pass |
28 | 18 |
|
29 | 19 |
|
30 | | -class Connection(object): |
| 20 | +class Connection: |
31 | 21 | current = None |
32 | 22 | connections: Dict[str, 'Connection'] = {} |
33 | 23 | bootstrapping_failures: Dict[str, str] = {} |
@@ -107,23 +97,30 @@ def __init__(self, connect_str=None, name=None, human_name=None, **create_engine |
107 | 97 | self.metadata = sqlalchemy.MetaData(bind=self._engine) |
108 | 98 | self.name = name or self.assign_name(self._engine) |
109 | 99 | self.human_name = human_name |
110 | | - self._session = None |
| 100 | + self._sqla_connection = None |
111 | 101 | self.connections[name or repr(self.metadata.bind.url)] = self |
112 | 102 |
|
113 | 103 | Connection.current = self |
114 | 104 |
|
115 | 105 | @property |
116 | | - def session(self) -> sqlalchemy.engine.base.Connection: |
117 | | - """Lazily connect to the database. |
| 106 | + def engine(self) -> sqlalchemy.engine.base.Engine: |
| 107 | + return self._engine |
118 | 108 |
|
119 | | - Despite the name, this is a SQLA Connection, not a Session. And 'Connection' |
120 | | - is highly overused term around here. |
121 | | - """ |
| 109 | + @property |
| 110 | + def sqla_connection(self) -> sqlalchemy.engine.base.Connection: |
| 111 | + """Lazily connect to the database. Return a SQLA Connection object, or die trying.""" |
| 112 | + |
| 113 | + if not self._sqla_connection: |
| 114 | + self._sqla_connection = self._engine.connect() |
122 | 115 |
|
123 | | - if not self._session: |
124 | | - self._session = self._engine.connect() |
| 116 | + return self._sqla_connection |
125 | 117 |
|
126 | | - return self._session |
| 118 | + def reset_connection_pool(self): |
| 119 | + """Reset the SQLA connection pool, such as after an exception suspected to indicate |
| 120 | + a broken connection has been raised. |
| 121 | + """ |
| 122 | + self._engine.dispose() |
| 123 | + self._sqla_connection = None |
127 | 124 |
|
128 | 125 | @classmethod |
129 | 126 | def set( |
@@ -164,14 +161,23 @@ def connection_list(cls): |
164 | 161 | result.append(template.format(engine_url.__repr__())) |
165 | 162 | return "\n".join(result) |
166 | 163 |
|
| 164 | + @classmethod |
| 165 | + def find(cls, name: str) -> Optional['Connection']: |
| 166 | + """Find a connection by SQL cell handle or by human assigned name""" |
| 167 | + # TODO: Capt. Obvious says to double-register the instance by both of these keys |
| 168 | + # to then be able to do lookups properly in this dict? |
| 169 | + for c in cls.connections.values(): |
| 170 | + if c.name == name or c.human_name == name: |
| 171 | + return c |
| 172 | + |
167 | 173 | @classmethod |
168 | 174 | def get_engine(cls, name: str) -> Optional[Engine]: |
169 | 175 | """Return the SQLAlchemy Engine given either the sql_cell_handle or |
170 | 176 | end-user assigned name for the connection. |
171 | 177 | """ |
172 | | - for c in cls.connections.values(): |
173 | | - if c.name == name or c.human_name == name: |
174 | | - return c._engine |
| 178 | + maybe_conn = cls.find(name) |
| 179 | + if maybe_conn: |
| 180 | + return maybe_conn.engine |
175 | 181 |
|
176 | 182 | @classmethod |
177 | 183 | def add_bootstrapping_failure(cls, name: str, human_name: Optional[str], error_message: str): |
@@ -204,7 +210,56 @@ def _close(cls, descriptor): |
204 | 210 | ) |
205 | 211 | cls.connections.pop(conn.name, None) |
206 | 212 | cls.connections.pop(str(conn.metadata.bind.url), None) |
207 | | - conn.session.close() |
| 213 | + conn.sqla_connection.close() |
208 | 214 |
|
209 | 215 | def close(self): |
210 | 216 | self.__class__._close(self) |
| 217 | + |
| 218 | + |
| 219 | +def rough_dict_get(dct, sought, default=None): |
| 220 | + """ |
| 221 | + Like dct.get(sought), but any key containing sought will do. |
| 222 | +
|
| 223 | + If there is a `@` in sought, seek each piece separately. |
| 224 | + This lets `me@server` match `me:***@myserver/db` |
| 225 | + """ |
| 226 | + |
| 227 | + sought = sought.split("@") |
| 228 | + for key, val in dct.items(): |
| 229 | + if not any(s.lower() not in key.lower() for s in sought): |
| 230 | + return val |
| 231 | + return default |
| 232 | + |
| 233 | + |
| 234 | +def get_db_connection(name_or_handle: str) -> Optional[Connection]: |
| 235 | + """Return the noteable.sql.connection.Connection corresponding to the requested |
| 236 | + datasource a name or handle. |
| 237 | +
|
| 238 | + Will return None if the given handle isn't present in |
| 239 | + the connections dict already (created after this kernel was launched?) |
| 240 | + """ |
| 241 | + return Connection.find(name_or_handle) |
| 242 | + |
| 243 | + |
| 244 | +def get_sqla_connection(name_or_handle: str) -> Optional[sqlalchemy.engine.base.Connection]: |
| 245 | + """Return a SQLAlchemy connection given a name or handle |
| 246 | + Returns None if cannot find by this string. |
| 247 | + """ |
| 248 | + nconn = get_db_connection(name_or_handle) |
| 249 | + if nconn: |
| 250 | + return nconn.sqla_connection |
| 251 | + |
| 252 | + |
| 253 | +def get_sqla_engine(name_or_handle: str) -> Optional[Engine]: |
| 254 | + """Return a SQLAlchemy Engine given a name or handle. |
| 255 | + Returns None if cannot find by this string. |
| 256 | + """ |
| 257 | + return Connection.get_engine(name_or_handle) |
| 258 | + |
| 259 | + |
| 260 | +def bootstrap_duckdb(): |
| 261 | + Connection.set( |
| 262 | + DUCKDB_LOCATION, |
| 263 | + human_name=LOCAL_DB_CONN_NAME, |
| 264 | + name=LOCAL_DB_CONN_HANDLE, |
| 265 | + ) |
0 commit comments