# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict
import sqlite3
from deeppavlov.core.models.estimator import Estimator
from deeppavlov.core.common.registry import register
from deeppavlov.core.common.log import get_logger
log = get_logger(__name__)
[docs]@register('sqlite_database')
class Sqlite3Database(Estimator):
"""
Loads and trains sqlite table of any items (with name ``table_name``
and path ``save_path``).
Primary (unique) keys must be specified, all other keys are infered from data.
Batch here is a list of dictionaries, where each dictionary corresponds to an item.
If an item doesn't contain values for all keys, then missing values will be stored
with ``unknown_value``.
Parameters:
save_path: sqlite database path.
table_name: name of the sqlite table.
primary_keys: list of table primary keys' names.
keys: all table keys' names.
unknown_value: value assigned to missing item values.
**kwargs: parameters passed to parent :class:`~deeppavlov.core.models.estimator.Estimator` class.
"""
def __init__(self, save_path: str,
table_name: str,
primary_keys: List[str],
keys: List[str] = None,
unknown_value: str = 'UNK',
*args, **kwargs) -> None:
super().__init__(save_path=save_path, *args, **kwargs)
self.primary_keys = primary_keys
if not self.primary_keys:
raise ValueError("Primary keys list can't be empty")
self.tname = table_name
self.keys = keys
self.unknown_value = unknown_value
self.conn = sqlite3.connect(str(self.save_path), check_same_thread=False)
self.cursor = self.conn.cursor()
if self._check_if_table_exists():
log.info("Loading database from {}.".format(self.save_path))
if not self.keys:
self.keys = self._get_keys()
else:
log.info("Initializing empty database on {}.".format(self.save_path))
def __call__(self, batch: List[Dict],
order_by: str = None,
ascending: bool = False) -> List[List[Dict]]:
order = 'ASC' if ascending else 'DESC'
if not self._check_if_table_exists():
log.warn("Database is empty, call fit() before using.")
return [[] for i in range(len(batch))]
return [self._search(b, order_by=order_by, order=order) for b in batch]
def _check_if_table_exists(self):
self.cursor.execute("SELECT name FROM sqlite_master"
" WHERE type='table'"
" AND name='{}';".format(self.tname))
return bool(self.cursor.fetchall())
def _search(self, kv, order_by, order):
if not kv:
# get all table content
if order_by is not None:
self.cursor.execute("SELECT * FROM {}".format(self.tname) +
" ORDER BY {} {}".format(order_by, order))
else:
self.cursor.execute("SELECT * FROM {}".format(self.tname))
else:
keys = list(kv.keys())
values = [kv[k] for k in keys]
where_expr = ' AND '.join('{}=?'.format(k) for k in keys)
if order_by is not None:
self.cursor.execute("SELECT * FROM {}".format(self.tname) +
" WHERE {}".format(where_expr) +
" ORDER BY {} {}".format(order_by, order),
values)
else:
self.cursor.execute("SELECT * FROM {}".format(self.tname) +
" WHERE {}".format(where_expr),
values)
return [self._wrap_selection(s) for s in self.cursor.fetchall()]
def _wrap_selection(self, selection):
if not self.keys:
self.keys = self._get_keys()
return {f: v for f, v in zip(self.keys, selection)}
def _get_keys(self):
self.cursor.execute("PRAGMA table_info({});".format(self.tname))
return [info[1] for info in self.cursor]
def _get_types(self):
self.cursor.execute("PRAGMA table_info({});".format(self.tname))
return {info[1]: info[2] for info in self.cursor}
def fit(self, data: List[Dict]) -> None:
if not self._check_if_table_exists():
self.keys = self.keys or list(set(k for d in data for k in d.keys()))
types = ('integer' if type(data[0][k]) == int else 'text' for k in self.keys)
self._create_table(self.keys, types)
elif not self.keys:
self.keys = self._get_keys()
self._insert_many(data)
def _create_table(self, keys, types):
if any(pk not in keys for pk in self.primary_keys):
raise ValueError("Primary keys must be from {}.".format(keys))
new_types = ("{} {} primary key".format(k, t) if k in self.primary_keys else
"{} {}".format(k, t)
for k, t in zip(keys, types))
self.cursor.execute("CREATE TABLE IF NOT EXISTS {} ({})"
.format(self.tname, ', '.join(new_types)))
log.info("Created table with keys {}.".format(self._get_types()))
def _insert_many(self, data):
to_insert = {}
to_update = {}
for kv in filter(None, data):
primary_values = tuple(kv[pk] for pk in self.primary_keys)
record = tuple(kv.get(k, self.unknown_value) for k in self.keys)
curr_record = self._get_record(primary_values)
if curr_record:
if primary_values in to_update:
curr_record = to_update[primary_values]
if curr_record != record:
to_update[primary_values] = record
else:
to_insert[primary_values] = record
if to_insert:
fformat = ','.join(['?'] * len(self.keys))
self.cursor.executemany("INSERT into {}".format(self.tname) +
" VALUES ({})".format(fformat),
to_insert.values())
if to_update:
for record in to_update.values():
self._update_one(record)
self.conn.commit()
def _get_record(self, primary_values):
ffields = ', '.join(self.keys) or '*'
where_expr = ' AND '.join("{} = '{}'".format(pk, v)
for pk, v in zip(self.primary_keys, primary_values))
fetched = self.cursor.execute("SELECT {} FROM {}".format(ffields, self.tname) +
" WHERE {}".format(where_expr)).fetchone()
if not fetched:
return None
return fetched
def _update_one(self, record):
set_expr = ', '.join("{} = '{}'".format(k, v)
for k, v in zip(self.keys, record)
if k not in self.primary_keys)
where_expr = ' AND '.join("{} = '{}'".format(k, v)
for k, v in zip(self.keys, record)
if k in self.primary_keys)
self.cursor.execute("UPDATE {}".format(self.tname) +
" SET {}".format(set_expr) +
" WHERE {}".format(where_expr))
def save(self):
pass
def load(self):
pass