Source code for b2sdk.account_info.sqlite_account_info

######################################################################
#
# File: b2sdk/account_info/sqlite_account_info.py
#
# Copyright 2019 Backblaze Inc. All Rights Reserved.
#
# License https://www.backblaze.com/using_b2_code.html
#
######################################################################

import json
import logging
import os
import stat
import threading

from .exception import (CorruptAccountInfo, MissingAccountData)
from .upload_url_pool import UrlPoolAccountInfo

import sqlite3

logger = logging.getLogger(__name__)

B2_ACCOUNT_INFO_ENV_VAR = 'B2_ACCOUNT_INFO'
B2_ACCOUNT_INFO_DEFAULT_FILE = '~/.b2_account_info'
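
# The database path is resolved in the constructor below: an explicit
# file_name argument is used as-is; otherwise the value of B2_ACCOUNT_INFO
# (or, failing that, the default above) is expanded with os.path.expanduser.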


class SqliteAccountInfo(UrlPoolAccountInfo):
    """
    Store account information in an `sqlite3 <https://www.sqlite.org>`_
    database which is used to manage concurrent access to the data.

    The ``update_done`` table tracks the schema updates that have been completed.
    """
    def __init__(self, file_name=None, last_upgrade_to_run=None):
        """
        If the ``file_name`` argument is empty or ``None``, the path from the
        ``B2_ACCOUNT_INFO`` environment variable is used.  If that is not
        available, a default of ``~/.b2_account_info`` is used.

        :param str file_name: The sqlite file to use; overrides the default.
        :param int last_upgrade_to_run: For testing only, override the auto-update on the db.
        """
        self.thread_local = threading.local()
        user_account_info_path = file_name or os.environ.get(
            B2_ACCOUNT_INFO_ENV_VAR, B2_ACCOUNT_INFO_DEFAULT_FILE
        )
        self.filename = file_name or os.path.expanduser(user_account_info_path)
        logger.debug('%s file path to use: %s', self.__class__.__name__, self.filename)

        self._validate_database()
        with self._get_connection() as conn:
            self._create_tables(conn, last_upgrade_to_run)
        super(SqliteAccountInfo, self).__init__()
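
    # Note: _validate_database() below can also convert an account file saved
    # in the old JSON representation.  Such a file is a flat JSON object
    # containing the keys listed in that method (account_id, application_key,
    # account_auth_token, api_url, download_url, minimum_part_size, realm);
    # the shape shown here is illustrative only:
    #
    #     {"account_id": "...", "application_key": "...", ...}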
    def _validate_database(self, last_upgrade_to_run=None):
        """
        Make sure that the database is openable.  Converts the file from the
        old JSON representation if necessary, and raises
        :class:`CorruptAccountInfo` if the file cannot be used.
        """
        # If there is no file there, that's fine.  It will get created when
        # we connect.
        if not os.path.exists(self.filename):
            self._create_database(last_upgrade_to_run)
            return

        # If we can connect to the database, and do anything, then all is good.
        try:
            with self._connect() as conn:
                self._create_tables(conn, last_upgrade_to_run)
                return
        except sqlite3.DatabaseError:
            pass  # fall through to next case

        # If the file contains JSON with the right stuff in it, convert from
        # the old representation.
        try:
            with open(self.filename, 'rb') as f:
                data = json.loads(f.read().decode('utf-8'))
                keys = [
                    'account_id', 'application_key', 'account_auth_token', 'api_url',
                    'download_url', 'minimum_part_size', 'realm'
                ]
                if all(k in data for k in keys):
                    # remove the json file
                    os.unlink(self.filename)
                    # create a database
                    self._create_database(last_upgrade_to_run)
                    # add the data from the JSON file
                    with self._connect() as conn:
                        self._create_tables(conn, last_upgrade_to_run)
                        insert_statement = """
                            INSERT INTO account
                            (account_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm)
                            values (?, ?, ?, ?, ?, ?, ?);
                        """
                        conn.execute(insert_statement, tuple(data[k] for k in keys))
                    # all is happy now
                    return
        except ValueError:  # includes json.decoder.JSONDecodeError
            pass

        # The file is neither a usable database nor a convertible JSON file,
        # so report it as corrupt.
        raise CorruptAccountInfo(self.filename)

    def _get_connection(self):
        """
        Connections to sqlite cannot be shared across threads.
        """
        try:
            return self.thread_local.connection
        except AttributeError:
            self.thread_local.connection = self._connect()
            return self.thread_local.connection

    def _connect(self):
        return sqlite3.connect(self.filename, isolation_level='EXCLUSIVE')

    def _create_database(self, last_upgrade_to_run):
        """
        Make sure that the database is created and set the file permissions.
        This should be done before storing any sensitive data in it.
        """
        # Create the tables in the database
        conn = self._connect()
        try:
            with conn:
                self._create_tables(conn, last_upgrade_to_run)
        finally:
            conn.close()

        # Set the file permissions
        os.chmod(self.filename, stat.S_IRUSR | stat.S_IWUSR)

    def _create_tables(self, conn, last_upgrade_to_run):
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS update_done (
                update_number INT NOT NULL
            );
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS account (
                account_id TEXT NOT NULL,
                application_key TEXT NOT NULL,
                account_auth_token TEXT NOT NULL,
                api_url TEXT NOT NULL,
                download_url TEXT NOT NULL,
                minimum_part_size INT NOT NULL,
                realm TEXT NOT NULL
            );
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS bucket (
                bucket_name TEXT NOT NULL,
                bucket_id TEXT NOT NULL
            );
            """
        )
        # This table is not used any more.  We may use it again
        # someday if we save upload URLs across invocations of
        # the command-line tool.
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS bucket_upload_url (
                bucket_id TEXT NOT NULL,
                upload_url TEXT NOT NULL,
                upload_auth_token TEXT NOT NULL
            );
            """
        )
        # By default, we run all the upgrades
        last_upgrade_to_run = 2 if last_upgrade_to_run is None else last_upgrade_to_run
        # Add the 'allowed' column if it hasn't been added yet.
        if 1 <= last_upgrade_to_run:
            self._ensure_update(1, 'ALTER TABLE account ADD COLUMN allowed TEXT;')
        # Add the 'account_id_or_app_key_id' column if it hasn't been added yet
        if 2 <= last_upgrade_to_run:
            self._ensure_update(2, 'ALTER TABLE account ADD COLUMN account_id_or_app_key_id TEXT;')

    def _ensure_update(self, update_number, update_command):
        """
        Run the update with the given number if it hasn't been done yet.

        Does the update and stores the number as a single transaction,
        so they will always be in sync.
        """
        with self._get_connection() as conn:
            conn.execute('BEGIN')
            cursor = conn.execute(
                'SELECT COUNT(*) AS count FROM update_done WHERE update_number = ?;',
                (update_number,)
            )
            update_count = cursor.fetchone()[0]
            assert update_count in [0, 1]
            if update_count == 0:
                conn.execute(update_command)
                conn.execute(
                    'INSERT INTO update_done (update_number) VALUES (?);', (update_number,)
                )

    def clear(self):
        """
        Remove all info about accounts and buckets.
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')

    def _set_auth_data(
        self,
        account_id,
        auth_token,
        api_url,
        download_url,
        minimum_part_size,
        application_key,
        realm,
        allowed,
        application_key_id,
    ):
        assert self.allowed_is_valid(allowed)
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')
            insert_statement = """
                INSERT INTO account
                (account_id, account_id_or_app_key_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm, allowed)
                values (?, ?, ?, ?, ?, ?, ?, ?, ?);
            """
            conn.execute(
                insert_statement, (
                    account_id,
                    application_key_id,
                    application_key,
                    auth_token,
                    api_url,
                    download_url,
                    minimum_part_size,
                    realm,
                    json.dumps(allowed),
                )
            )

    def set_auth_data_with_schema_0_for_test(
        self,
        account_id,
        auth_token,
        api_url,
        download_url,
        minimum_part_size,
        application_key,
        realm,
    ):
        """
        Set authentication data for tests.

        :param str account_id: an account ID
        :param str auth_token: an authentication token
        :param str api_url: an API URL
        :param str download_url: a download URL
        :param int minimum_part_size: a minimum part size
        :param str application_key: an application key
        :param str realm: a realm to authorize account in
        """
        with self._get_connection() as conn:
            conn.execute('DELETE FROM account;')
            conn.execute('DELETE FROM bucket;')
            conn.execute('DELETE FROM bucket_upload_url;')
            insert_statement = """
                INSERT INTO account
                (account_id, application_key, account_auth_token, api_url, download_url, minimum_part_size, realm)
                values (?, ?, ?, ?, ?, ?, ?);
            """
            conn.execute(
                insert_statement, (
                    account_id,
                    application_key,
                    auth_token,
                    api_url,
                    download_url,
                    minimum_part_size,
                    realm,
                )
            )

    def get_application_key(self):
        return self._get_account_info_or_raise('application_key')

    def get_account_id(self):
        return self._get_account_info_or_raise('account_id')

    def get_application_key_id(self):
        """
        Return an application key ID.
        The 'account_id_or_app_key_id' column was not in the original schema, so it may be NULL.

        Nota bene - this is the only place where we are not renaming
        account_id_or_app_key_id to application_key_id, because doing so
        would require a column change.

        application_key_id == account_id_or_app_key_id

        :rtype: str
        """
        result = self._get_account_info_or_raise('account_id_or_app_key_id')
        if result is None:
            return self.get_account_id()
        else:
            return result

    def get_api_url(self):
        return self._get_account_info_or_raise('api_url')

    def get_account_auth_token(self):
        return self._get_account_info_or_raise('account_auth_token')

    def get_download_url(self):
        return self._get_account_info_or_raise('download_url')

    def get_realm(self):
        return self._get_account_info_or_raise('realm')

    def get_minimum_part_size(self):
        return self._get_account_info_or_raise('minimum_part_size')

    def get_allowed(self):
        """
        Return 'allowed' dictionary info.
        The 'allowed' column was not in the original schema, so it may be NULL.

        :rtype: dict
        """
        allowed_json = self._get_account_info_or_raise('allowed')
        if allowed_json is None:
            return self.DEFAULT_ALLOWED
        else:
            return json.loads(allowed_json)

    def _get_account_info_or_raise(self, column_name):
        try:
            with self._get_connection() as conn:
                cursor = conn.execute('SELECT %s FROM account;' % (column_name,))
                value = cursor.fetchone()[0]
                return value
        except Exception as e:
            logger.exception(
                '_get_account_info_or_raise encountered a problem while trying to retrieve "%s"',
                column_name
            )
            raise MissingAccountData(str(e))

    def refresh_entire_bucket_name_cache(self, name_id_iterable):
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket;')
            for (bucket_name, bucket_id) in name_id_iterable:
                conn.execute(
                    'INSERT INTO bucket (bucket_name, bucket_id) VALUES (?, ?);',
                    (bucket_name, bucket_id)
                )

    def save_bucket(self, bucket):
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket WHERE bucket_id = ?;', (bucket.id_,))
            conn.execute(
                'INSERT INTO bucket (bucket_id, bucket_name) VALUES (?, ?);',
                (bucket.id_, bucket.name)
            )

    def remove_bucket_name(self, bucket_name):
        with self._get_connection() as conn:
            conn.execute('DELETE FROM bucket WHERE bucket_name = ?;', (bucket_name,))

    def get_bucket_id_or_none_from_bucket_name(self, bucket_name):
        try:
            with self._get_connection() as conn:
                cursor = conn.execute(
                    'SELECT bucket_id FROM bucket WHERE bucket_name = ?;', (bucket_name,)
                )
                return cursor.fetchone()[0]
        except TypeError:  # TypeError: 'NoneType' object is unsubscriptable
            return None
        except sqlite3.Error:
            return None
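
The sketch below is an illustrative usage example, not part of the module above. It assumes a b2sdk installation that provides this class at the import path shown, exercises only methods defined in this file, and uses a placeholder file name and bucket values.

    from b2sdk.account_info.sqlite_account_info import SqliteAccountInfo

    # Placeholder path for this sketch; the constructor creates the database
    # with the tables defined above and restricts the file permissions to the
    # current user.
    info = SqliteAccountInfo(file_name='example_account_info.sqlite')

    # Populate the bucket-name cache, then look entries up.
    info.refresh_entire_bucket_name_cache([('example-bucket', 'example-bucket-id')])
    assert info.get_bucket_id_or_none_from_bucket_name('example-bucket') == 'example-bucket-id'
    assert info.get_bucket_id_or_none_from_bucket_name('unknown-bucket') is None

    # Remove all stored account and bucket data.
    info.clear()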