######################################################################
#
# File: b2sdk/account_info/sqlite_account_info.py
#
# Copyright 2019 Backblaze Inc. All Rights Reserved.
#
# License https://www.backblaze.com/using_b2_code.html
#
######################################################################
import json
import logging
import os
import stat
import threading
from .exception import (CorruptAccountInfo, MissingAccountData)
from .upload_url_pool import UrlPoolAccountInfo
import sqlite3
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Name of the environment variable that SqliteAccountInfo.__init__ consults
# for the database path when no explicit file name is given.
B2_ACCOUNT_INFO_ENV_VAR = 'B2_ACCOUNT_INFO'
# Path used when the environment variable above is not set; the leading '~'
# is expanded with os.path.expanduser before use.
B2_ACCOUNT_INFO_DEFAULT_FILE = '~/.b2_account_info'
class SqliteAccountInfo(UrlPoolAccountInfo):
    """
    Store account information in an `sqlite3 <https://www.sqlite.org>`_ database which is
    used to manage concurrent access to the data.

    The ``update_done`` table tracks the schema updates that have been
    completed.
    """

    def __init__(self, file_name=None, last_upgrade_to_run=None):
        """
        If ``file_name`` argument is empty or ``None``, path from ``B2_ACCOUNT_INFO``
        environment variable is used. If that is not available, a default of
        ``~/.b2_account_info`` is used.

        :param str file_name: The sqlite file to use; overrides the default.
        :param int last_upgrade_to_run: For testing only, override the auto-update on the db.
        """
        # sqlite connections cannot be shared across threads, so each thread
        # keeps its own connection in this thread-local storage.
        self.thread_local = threading.local()
        if file_name:
            # An explicitly given path is used verbatim (no '~' expansion).
            self.filename = file_name
        else:
            self.filename = os.path.expanduser(
                os.environ.get(B2_ACCOUNT_INFO_ENV_VAR, B2_ACCOUNT_INFO_DEFAULT_FILE)
            )
        logger.debug('%s file path to use: %s', self.__class__.__name__, self.filename)
        # Forward the override so that a freshly created or converted database
        # is not silently upgraded past the schema version the caller asked for.
        self._validate_database(last_upgrade_to_run)
        with self._get_connection() as conn:
            self._create_tables(conn, last_upgrade_to_run)
        super(SqliteAccountInfo, self).__init__()

    def _validate_database(self, last_upgrade_to_run=None):
        """
        Make sure that the database is openable.

        A missing file is created as a new database.  An existing sqlite
        database gets its tables created/upgraded.  A file in the old JSON
        format that contains all of the expected keys is replaced by a
        database holding the same account data.  The file is left untouched
        in any other case.

        :param int last_upgrade_to_run: For testing only, the last schema upgrade to apply.
        :raises CorruptAccountInfo: if the file is neither a usable database
                                    nor a complete old-style JSON file
        """
        # If there is no file there, that's fine.  Create it right away so
        # that the file permissions get set before anything is stored in it.
        if not os.path.exists(self.filename):
            self._create_database(last_upgrade_to_run)
            return

        # If we can connect to the database, and do anything, then all is good.
        try:
            conn = self._connect()
            try:
                with conn:
                    self._create_tables(conn, last_upgrade_to_run)
            finally:
                # ``with conn`` only commits or rolls back; close explicitly
                # so that the file handle is not kept open.
                conn.close()
            return
        except sqlite3.DatabaseError:
            pass  # fall through to next case

        # If the file contains JSON with the right stuff in it, convert from
        # the old representation.
        try:
            with open(self.filename, 'rb') as f:
                data = json.loads(f.read().decode('utf-8'))
            keys = [
                'account_id', 'application_key', 'account_auth_token', 'api_url',
                'download_url', 'minimum_part_size', 'realm'
            ]
            # Valid JSON that is not an object (a number, a list, ...) cannot
            # be an old account info file.
            if isinstance(data, dict) and all(k in data for k in keys):
                # remove the json file (it has already been closed above)
                os.unlink(self.filename)
                # create a database
                self._create_database(last_upgrade_to_run)
                # add the data from the JSON file
                conn = self._connect()
                try:
                    with conn:
                        self._create_tables(conn, last_upgrade_to_run)
                        insert_statement = """
                            INSERT INTO account
                            (account_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm)
                            values (?, ?, ?, ?, ?, ?, ?);
                        """
                        conn.execute(insert_statement, tuple(data[k] for k in keys))
                finally:
                    conn.close()
                # all is happy now
                return
        except ValueError:  # includes json.decoder.JSONDecodeError and UnicodeDecodeError
            pass

        # The file is neither a database nor usable JSON.  It is not removed;
        # the caller decides what to do with it.
        raise CorruptAccountInfo(self.filename)

    def _get_connection(self):
        """
        Return the connection belonging to the current thread, opening it on first use.

        Connections to sqlite cannot be shared across threads.
        """
        try:
            return self.thread_local.connection
        except AttributeError:
            self.thread_local.connection = self._connect()
            return self.thread_local.connection

    def _connect(self):
        """
        Open a new connection to ``self.filename`` with ``EXCLUSIVE`` isolation level.
        """
        return sqlite3.connect(self.filename, isolation_level='EXCLUSIVE')

    def _create_database(self, last_upgrade_to_run):
        """
        Make sure that the database is created and sets the file permissions.
        This should be done before storing any sensitive data in it.

        :param int last_upgrade_to_run: the last schema upgrade to apply, ``None`` for all
        """
        # Create the tables in the database
        conn = self._connect()
        try:
            with conn:
                self._create_tables(conn, last_upgrade_to_run)
        finally:
            conn.close()

        # Set the file permissions: read/write for the owner only
        os.chmod(self.filename, stat.S_IRUSR | stat.S_IWUSR)

    def _create_tables(self, conn, last_upgrade_to_run):
        """
        Create any missing tables and run the schema upgrades.

        The tables are created through ``conn``; the upgrades are run by
        :meth:`_ensure_update`, which uses the current thread's connection.

        :param conn: an open sqlite3 connection to the database
        :param int last_upgrade_to_run: the last schema upgrade to apply, ``None`` for all
        """
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS
            update_done (
                update_number INT NOT NULL
            );
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS
            account (
                account_id TEXT NOT NULL,
                application_key TEXT NOT NULL,
                account_auth_token TEXT NOT NULL,
                api_url TEXT NOT NULL,
                download_url TEXT NOT NULL,
                minimum_part_size INT NOT NULL,
                realm TEXT NOT NULL
            );
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS
            bucket (
                bucket_name TEXT NOT NULL,
                bucket_id TEXT NOT NULL
            );
            """
        )
        # This table is not used any more.  We may use it again
        # someday if we save upload URLs across invocations of
        # the command-line tool.
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS
            bucket_upload_url (
                bucket_id TEXT NOT NULL,
                upload_url TEXT NOT NULL,
                upload_auth_token TEXT NOT NULL
            );
            """
        )
        # By default, we run all the upgrades
        last_upgrade_to_run = 2 if last_upgrade_to_run is None else last_upgrade_to_run
        # Add the 'allowed' column if it hasn't been yet.
        if 1 <= last_upgrade_to_run:
            self._ensure_update(1, 'ALTER TABLE account ADD COLUMN allowed TEXT;')
        # Add the 'account_id_or_app_key_id' column if it hasn't been yet
        if 2 <= last_upgrade_to_run:
            self._ensure_update(2, 'ALTER TABLE account ADD COLUMN account_id_or_app_key_id TEXT;')

    def _ensure_update(self, update_number, update_command):
        """
        Run the update with the given number if it hasn't been done yet.

        Does the update and stores the number as a single transaction,
        so they will always be in sync.

        :param int update_number: the number identifying the schema update
        :param str update_command: the SQL statement that performs the update
        """
        with self._get_connection() as conn:
            conn.execute('BEGIN')
            cursor = conn.execute(
                'SELECT COUNT(*) AS count FROM update_done WHERE update_number = ?;',
                (update_number,)
            )
            update_count = cursor.fetchone()[0]
            assert update_count in [0, 1]
            if update_count == 0:
                conn.execute(update_command)
                conn.execute(
                    'INSERT INTO update_done (update_number) VALUES (?);', (update_number,)
                )

    def clear(self):
        """
        Remove all info about accounts and buckets.
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')

    def _set_auth_data(
        self,
        account_id,
        auth_token,
        api_url,
        download_url,
        minimum_part_size,
        application_key,
        realm,
        allowed,
        application_key_id,
    ):
        """
        Replace everything stored in the account, bucket and upload URL
        tables with the given authorization data.

        ``allowed`` is stored JSON-encoded and ``application_key_id`` is
        stored in the ``account_id_or_app_key_id`` column.
        """
        assert self.allowed_is_valid(allowed)
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')
            insert_statement = """
                INSERT INTO account
                (account_id, account_id_or_app_key_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm, allowed)
                values (?, ?, ?, ?, ?, ?, ?, ?, ?);
            """
            conn.execute(
                insert_statement, (
                    account_id,
                    application_key_id,
                    application_key,
                    auth_token,
                    api_url,
                    download_url,
                    minimum_part_size,
                    realm,
                    json.dumps(allowed),
                )
            )

    def set_auth_data_with_schema_0_for_test(
        self,
        account_id,
        auth_token,
        api_url,
        download_url,
        minimum_part_size,
        application_key,
        realm,
    ):
        """
        Set authentication data for tests.

        Only the columns of the original schema are written.

        :param str account_id: an account ID
        :param str auth_token: an authentication token
        :param str api_url: an API URL
        :param str download_url: a download URL
        :param int minimum_part_size: a minimum part size
        :param str application_key: an application key
        :param str realm: a realm to authorize account in
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')
            insert_statement = """
                INSERT INTO account
                (account_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm)
                values (?, ?, ?, ?, ?, ?, ?);
            """
            conn.execute(
                insert_statement, (
                    account_id,
                    application_key,
                    auth_token,
                    api_url,
                    download_url,
                    minimum_part_size,
                    realm,
                )
            )

    def get_application_key(self):
        """
        Return the stored ``application_key`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('application_key')

    def get_account_id(self):
        """
        Return the stored ``account_id`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('account_id')

    def get_application_key_id(self):
        """
        Return an application key ID.
        The 'account_id_or_app_key_id' column was not in the original schema, so it may be NULL.

        Nota bene - this is the only place where we are not renaming account_id_or_app_key_id to application_key_id
        because it requires a column change.

        application_key_id == account_id_or_app_key_id

        :rtype: str
        """
        result = self._get_account_info_or_raise('account_id_or_app_key_id')
        if result is None:
            # Nothing stored in the newer column: fall back to the account ID.
            return self.get_account_id()
        else:
            return result

    def get_api_url(self):
        """
        Return the stored ``api_url`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('api_url')

    def get_account_auth_token(self):
        """
        Return the stored ``account_auth_token`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('account_auth_token')

    def get_download_url(self):
        """
        Return the stored ``download_url`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('download_url')

    def get_realm(self):
        """
        Return the stored ``realm`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('realm')

    def get_minimum_part_size(self):
        """
        Return the stored ``minimum_part_size`` value.

        :raises MissingAccountData: if the value cannot be read
        """
        return self._get_account_info_or_raise('minimum_part_size')

    def get_allowed(self):
        """
        Return 'allowed' dictionary info.
        The 'allowed' column was not in the original schema, so it may be NULL.

        :rtype: dict
        """
        allowed_json = self._get_account_info_or_raise('allowed')
        if allowed_json is None:
            return self.DEFAULT_ALLOWED
        else:
            return json.loads(allowed_json)

    def _get_account_info_or_raise(self, column_name):
        """
        Return the value of ``column_name`` from the first row of the account table.

        ``column_name`` is interpolated directly into the SQL text, so it
        must only ever be one of the fixed column names used in this class,
        never externally supplied data.

        :param str column_name: name of a column of the ``account`` table
        :raises MissingAccountData: if there is no account row or the query fails
        """
        try:
            with self._get_connection() as conn:
                cursor = conn.execute('SELECT %s FROM account;' % (column_name,))
                # fetchone() returns None for an empty table; the resulting
                # TypeError is reported as missing account data below.
                value = cursor.fetchone()[0]
                return value
        except Exception as e:
            logger.exception(
                '_get_account_info_or_raise encountered a problem while trying to retrieve "%s"',
                column_name
            )
            raise MissingAccountData(str(e))

    def refresh_entire_bucket_name_cache(self, name_id_iterable):
        """
        Replace the whole content of the bucket table.

        :param name_id_iterable: an iterable of ``(bucket_name, bucket_id)`` pairs
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket;')
            for (bucket_name, bucket_id) in name_id_iterable:
                conn.execute(
                    'INSERT INTO bucket (bucket_name, bucket_id) VALUES (?, ?);',
                    (bucket_name, bucket_id)
                )

    def save_bucket(self, bucket):
        """
        Store the name of a bucket, replacing any row with the same bucket ID.

        :param bucket: an object providing ``id_`` and ``name`` attributes
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket WHERE bucket_id = ?;', (bucket.id_,))
            conn.execute(
                'INSERT INTO bucket (bucket_id, bucket_name) VALUES (?, ?);',
                (bucket.id_, bucket.name)
            )

    def remove_bucket_name(self, bucket_name):
        """
        Remove every row of the bucket table that has the given name.

        :param str bucket_name: a bucket name
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket WHERE bucket_name = ?;', (bucket_name,))

    def get_bucket_id_or_none_from_bucket_name(self, bucket_name):
        """
        Look up the ID stored for a bucket name.

        :param str bucket_name: a bucket name
        :return: the bucket ID, or ``None`` if the name is not stored or the
                 lookup fails with an sqlite error
        """
        try:
            with self._get_connection() as conn:
                cursor = conn.execute(
                    'SELECT bucket_id FROM bucket WHERE bucket_name = ?;', (bucket_name,)
                )
                return cursor.fetchone()[0]
        except TypeError:  # TypeError: 'NoneType' object is unsubscriptable
            return None
        except sqlite3.Error:
            return None