clea context manager for connections, type hunts docstrings and easy plug and play postgresSQL swap
This commit is contained in:
315
cogs/music/db_manager.py
Normal file
315
cogs/music/db_manager.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
Database Manager for Groovy-Zilean
|
||||
Centralizes all database operations and provides a clean interface.
|
||||
Makes future PostgreSQL migration much easier.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List, Tuple, Any
|
||||
import config
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""Manages database connections and operations"""
|
||||
|
||||
def __init__(self):
|
||||
self.db_path = config.get_db_path()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""
|
||||
Context manager for database connections.
|
||||
Automatically handles commit/rollback and closing.
|
||||
|
||||
Usage:
|
||||
with db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(...)
|
||||
"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def initialize_tables(self):
|
||||
"""Create database tables if they don't exist"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create servers table
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS servers (
|
||||
server_id TEXT PRIMARY KEY,
|
||||
is_playing INTEGER DEFAULT 0,
|
||||
song_name TEXT,
|
||||
song_url TEXT,
|
||||
song_thumbnail TEXT,
|
||||
loop_mode TEXT DEFAULT 'off',
|
||||
volume INTEGER DEFAULT 100,
|
||||
effect TEXT DEFAULT 'none',
|
||||
song_start_time REAL DEFAULT 0,
|
||||
song_duration INTEGER DEFAULT 0
|
||||
);''')
|
||||
|
||||
# Set all to not playing on startup
|
||||
cursor.execute("UPDATE servers SET is_playing = 0;")
|
||||
|
||||
# Migrations for existing databases - add columns if missing
|
||||
migrations = [
|
||||
("loop_mode", "TEXT DEFAULT 'off'"),
|
||||
("volume", "INTEGER DEFAULT 100"),
|
||||
("effect", "TEXT DEFAULT 'none'"),
|
||||
("song_start_time", "REAL DEFAULT 0"),
|
||||
("song_duration", "INTEGER DEFAULT 0"),
|
||||
("song_thumbnail", "TEXT DEFAULT ''"),
|
||||
("song_url", "TEXT DEFAULT ''")
|
||||
]
|
||||
|
||||
for col_name, col_type in migrations:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE servers ADD COLUMN {col_name} {col_type};")
|
||||
except sqlite3.OperationalError:
|
||||
# Column already exists, skip
|
||||
pass
|
||||
|
||||
# Create songs/queue table
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS songs (
|
||||
server_id TEXT NOT NULL,
|
||||
song_link TEXT,
|
||||
queued_by TEXT,
|
||||
position INTEGER NOT NULL,
|
||||
title TEXT,
|
||||
thumbnail TEXT,
|
||||
duration INTEGER,
|
||||
PRIMARY KEY (position),
|
||||
FOREIGN KEY (server_id) REFERENCES servers(server_id)
|
||||
);''')
|
||||
|
||||
# Clear all songs on startup
|
||||
cursor.execute("DELETE FROM songs;")
|
||||
|
||||
# ===================================
|
||||
# Server Operations
|
||||
# ===================================
|
||||
|
||||
def ensure_server_exists(self, server_id: str) -> None:
|
||||
"""Add server to database if it doesn't exist"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT COUNT(*) FROM servers WHERE server_id = ?', (server_id,))
|
||||
if cursor.fetchone()[0] == 0:
|
||||
cursor.execute('''INSERT INTO servers (server_id, loop_mode, volume, effect, song_thumbnail, song_url)
|
||||
VALUES (?, 'off', 100, 'none', '', '')''', (server_id,))
|
||||
|
||||
def set_server_playing(self, server_id: str, playing: bool) -> None:
|
||||
"""Update server playing status"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
val = 1 if playing else 0
|
||||
cursor.execute("UPDATE servers SET is_playing = ? WHERE server_id = ?", (val, server_id))
|
||||
|
||||
def is_server_playing(self, server_id: str) -> bool:
|
||||
"""Check if server is currently playing"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("SELECT is_playing FROM servers WHERE server_id = ?", (server_id,))
|
||||
res = cursor.fetchone()
|
||||
return True if res and res[0] == 1 else False
|
||||
|
||||
def set_current_song(self, server_id: str, title: str, url: str, thumbnail: str = "", duration: int = 0, start_time: float = 0) -> None:
|
||||
"""Update currently playing song information"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
# Ensure duration is an integer
|
||||
try:
|
||||
duration = int(duration)
|
||||
except:
|
||||
duration = 0
|
||||
|
||||
cursor.execute(''' UPDATE servers
|
||||
SET song_name = ?, song_url = ?, song_thumbnail = ?, song_start_time = ?, song_duration = ?
|
||||
WHERE server_id = ?''',
|
||||
(title, url, thumbnail, start_time, duration, server_id))
|
||||
|
||||
def get_current_song(self, server_id: str) -> dict:
|
||||
"""Get current song info"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(''' SELECT song_name, song_thumbnail, song_url FROM servers WHERE server_id = ? LIMIT 1;''', (server_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
return {'title': result[0], 'thumbnail': result[1], 'url': result[2]}
|
||||
return {'title': "Nothing", 'thumbnail': None, 'url': ''}
|
||||
|
||||
def get_current_progress(self, server_id: str) -> Tuple[int, int, float]:
|
||||
"""Get playback progress (elapsed, duration, percentage)"""
|
||||
import time
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''SELECT song_start_time, song_duration, is_playing FROM servers WHERE server_id = ? LIMIT 1;''', (server_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result or result[2] == 0:
|
||||
return 0, 0, 0.0
|
||||
|
||||
start_time, duration, _ = result
|
||||
|
||||
if duration is None or duration == 0:
|
||||
return 0, 0, 0.0
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
elapsed = min(elapsed, duration)
|
||||
percentage = (elapsed / duration) * 100 if duration > 0 else 0
|
||||
|
||||
return elapsed, duration, percentage
|
||||
|
||||
# ===================================
|
||||
# Queue Operations
|
||||
# ===================================
|
||||
|
||||
def add_song(self, server_id: str, song_link: str, queued_by: str, title: str, thumbnail: str = "", duration: int = 0, position: Optional[int] = None) -> int:
|
||||
"""Add song to queue, returns position"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
|
||||
if position is None:
|
||||
# Add to end
|
||||
cursor.execute("SELECT MAX(position) FROM songs WHERE server_id = ?", (server_id,))
|
||||
max_pos = cursor.fetchone()[0]
|
||||
position = (max_pos + 1) if max_pos is not None else 0
|
||||
else:
|
||||
# Insert at specific position (shift others down)
|
||||
cursor.execute("UPDATE songs SET position = position + 1 WHERE server_id = ? AND position >= ?",
|
||||
(server_id, position))
|
||||
|
||||
cursor.execute("""INSERT INTO songs VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
||||
(server_id, song_link, queued_by, position, title, thumbnail, duration))
|
||||
|
||||
return position
|
||||
|
||||
def get_next_song(self, server_id: str) -> Optional[Tuple]:
|
||||
"""Get the next song in queue (doesn't remove it)"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''SELECT * FROM songs WHERE server_id = ? ORDER BY position LIMIT 1;''', (server_id,))
|
||||
return cursor.fetchone()
|
||||
|
||||
def remove_song(self, server_id: str, position: int) -> None:
|
||||
"""Remove song at position from queue"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''DELETE FROM songs WHERE server_id = ? AND position = ?''', (server_id, position))
|
||||
|
||||
def get_queue(self, server_id: str, limit: int = 10) -> Tuple[int, List[Tuple]]:
|
||||
"""Get songs in queue (returns max_position, list of songs)"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
|
||||
cursor.execute("SELECT title, duration, queued_by FROM songs WHERE server_id = ? ORDER BY position LIMIT ?",
|
||||
(server_id, limit))
|
||||
songs = cursor.fetchall()
|
||||
|
||||
cursor.execute("SELECT MAX(position) FROM songs WHERE server_id = ?", (server_id,))
|
||||
max_pos = cursor.fetchone()[0]
|
||||
max_pos = max_pos if max_pos is not None else -1
|
||||
|
||||
return max_pos, songs
|
||||
|
||||
def clear_queue(self, server_id: str) -> None:
|
||||
"""Clear all songs from queue"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("DELETE FROM songs WHERE server_id = ?", (server_id,))
|
||||
|
||||
def shuffle_queue(self, server_id: str) -> bool:
|
||||
"""Shuffle the queue randomly, returns success"""
|
||||
import random
|
||||
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
|
||||
cursor.execute("SELECT position, song_link, queued_by, title, thumbnail, duration FROM songs WHERE server_id = ? ORDER BY position",
|
||||
(server_id,))
|
||||
songs = cursor.fetchall()
|
||||
|
||||
if len(songs) <= 1:
|
||||
return False
|
||||
|
||||
random.shuffle(songs)
|
||||
cursor.execute("DELETE FROM songs WHERE server_id = ?", (server_id,))
|
||||
|
||||
for i, s in enumerate(songs):
|
||||
cursor.execute("INSERT INTO songs VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
(server_id, s[1], s[2], i, s[3], s[4], s[5]))
|
||||
|
||||
return True
|
||||
|
||||
# ===================================
|
||||
# Settings Operations
|
||||
# ===================================
|
||||
|
||||
def get_loop_mode(self, server_id: str) -> str:
|
||||
"""Get loop mode: 'off', 'song', or 'queue'"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("SELECT loop_mode FROM servers WHERE server_id = ?", (server_id,))
|
||||
res = cursor.fetchone()
|
||||
return res[0] if res else 'off'
|
||||
|
||||
def set_loop_mode(self, server_id: str, mode: str) -> None:
|
||||
"""Set loop mode: 'off', 'song', or 'queue'"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("UPDATE servers SET loop_mode = ? WHERE server_id = ?", (mode, server_id))
|
||||
|
||||
def get_volume(self, server_id: str) -> int:
|
||||
"""Get volume (0-200)"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("SELECT volume FROM servers WHERE server_id = ?", (server_id,))
|
||||
res = cursor.fetchone()
|
||||
return res[0] if res else 100
|
||||
|
||||
def set_volume(self, server_id: str, volume: int) -> int:
|
||||
"""Set volume (0-200), returns the set volume"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("UPDATE servers SET volume = ? WHERE server_id = ?", (volume, server_id))
|
||||
return volume
|
||||
|
||||
def get_effect(self, server_id: str) -> str:
|
||||
"""Get current audio effect"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("SELECT effect FROM servers WHERE server_id = ?", (server_id,))
|
||||
res = cursor.fetchone()
|
||||
return res[0] if res else 'none'
|
||||
|
||||
def set_effect(self, server_id: str, effect: str) -> None:
|
||||
"""Set audio effect"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
self.ensure_server_exists(server_id)
|
||||
cursor.execute("UPDATE servers SET effect = ? WHERE server_id = ?", (effect, server_id))
|
||||
|
||||
|
||||
# Global instance
|
||||
db = DatabaseManager()
|
||||
Reference in New Issue
Block a user