"""
pickleDB - https://patx.github.io/pickledb
Harrison Erd - https://harrisonerd.com/
Licensed - BSD 3 Clause (see LICENSE)
"""

import asyncio
import os
from typing import Any
import uuid

import orjson
import aiofiles
try:
    import sqlite3
    import aiosqlite
    sqlite_enable = True
except ImportError:
    sqlite_enable = False


MISSING = object()


def in_async() -> bool:
    """Return True if we're currently running inside an event loop."""
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False


def dualmethod(func):
    """
    Decorator that lets an async method be called in both sync and async code.

    - In async code: returns the coroutine (you must `await` it).
    - In sync code: runs the coroutine with asyncio.run() and returns the result.
    """
    def wrapper(self, *args, **kwargs):
        coro = func(self, *args, **kwargs)
        if in_async():
            return coro
        return asyncio.run(coro)
    return wrapper


class PickleDB:
    """
    A unified async/sync key-value store using orjson + aiofiles.

    All data is kept in-memory in `self.db` and serialized to disk as a single
    orjson-encoded file at `self.location`.
    """

    def __init__(self, location: str):
        self.location = os.path.expanduser(location)
        self.db: dict[str, Any] = {}
        self._lock = asyncio.Lock()

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            self.save()

    async def __aenter__(self):
        await self.load()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            await self.save()

    @dualmethod
    async def load(self) -> bool:
        """
        Load JSON database from disk into memory.

        Returns `self` to allow chaining:
            db = PickleDB("Example.json").load()
        """
        if os.path.exists(self.location) and os.path.getsize(self.location) > 0:
            async with aiofiles.open(self.location, "rb") as f:
                data = await f.read()
            new_db = orjson.loads(data)
        else:
            new_db = {}

        async with self._lock:
            self.db = new_db
        return self

    @dualmethod
    async def save(self) -> bool:
        """
        Atomically save database to disk.

        Writes to `<location>.tmp` and then os.replace() over the original file.
        Returns True on success.
        """
        temp = f"{self.location}.tmp"
        async with self._lock:
            async with aiofiles.open(temp, "wb") as f:
                await f.write(orjson.dumps(self.db))
            await asyncio.to_thread(os.replace, temp, self.location)
        return True

    @dualmethod
    async def set(self, key, value) -> bool:
        """Set a key-value pair. Always returns True."""
        async with self._lock:
            self.db[str(key)] = value
        return True

    @dualmethod
    async def get(self, key, default=None):
        """Get a key's value, or `default` if missing."""
        async with self._lock:
            return self.db.get(str(key), default)

    @dualmethod
    async def remove(self, key) -> bool:
        """Remove a key-value pair. Returns True if it existed, False otherwise."""
        async with self._lock:
            return self.db.pop(str(key), None) is not None

    @dualmethod
    async def all(self):
        """Return a list of all keys."""
        async with self._lock:
            return list(self.db.keys())

    @dualmethod
    async def purge(self) -> bool:
        """Remove all key-value pairs from the database. Always returns True."""
        async with self._lock:
            self.db.clear()
        return True

if sqlite_enable:
    class PickleDBSQLite:
        """
        A unified async/sync key-value store backed by SQLite.

        Each key is stored as a row:

            CREATE TABLE kv (
                key   TEXT PRIMARY KEY,
                value BLOB NOT NULL
            )

        Values are stored as JSON-encoded bytes via orjson.
        """

        def __init__(
            self,
            sqlite_path: str = "pickledb.sqlite3",
            table_name: str = "kv",
        ) -> None:
            self.sqlite_path = sqlite_path
            self.table_name = table_name

            self._conn = sqlite3.connect(self.sqlite_path, check_same_thread=False)
            self._conn.row_factory = sqlite3.Row

            self._conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    key   TEXT PRIMARY KEY,
                    value BLOB NOT NULL
                )
                """
            )
            self._conn.commit()

        def _dumps(self, value: Any) -> bytes:
            """Serialize a Python object to orjson-encoded bytes."""
            return orjson.dumps(value)

        def _loads(self, data: bytes) -> Any:
            """Deserialize orjson-encoded bytes back into a Python object."""
            return orjson.loads(data)

        def set(self, key: str | None, value: Any) -> str:
            """
            Set a key-value pair.

            If key is None, generate a new random UUID key and return it.

            In async code, returns a coroutine you must `await`.
            In sync code, returns the key string directly.
            """
            if in_async():
                async def _aset() -> str:
                    async with aiosqlite.connect(self.sqlite_path) as db:
                        db.row_factory = sqlite3.Row
                        payload = self._dumps(value)

                        if key is None:
                            new_key = str(uuid.uuid4())
                            await db.execute(
                                f"INSERT INTO {self.table_name} (key, value) VALUES (?, ?)",
                                (new_key, payload),
                            )
                            await db.commit()
                            return new_key

                        await db.execute(
                            f"""
                            INSERT INTO {self.table_name} (key, value)
                            VALUES (?, ?)
                            ON CONFLICT(key) DO UPDATE SET value=excluded.value
                            """,
                            (str(key), payload),
                        )
                        await db.commit()
                        return str(key)

                return _aset()

            payload = self._dumps(value)
            if key is None:
                new_key = str(uuid.uuid4())
                self._conn.execute(
                    f"INSERT INTO {self.table_name} (key, value) VALUES (?, ?)",
                    (new_key, payload),
                )
                self._conn.commit()
                return new_key

            self._conn.execute(
                f"""
                INSERT INTO {self.table_name} (key, value)
                VALUES (?, ?)
                ON CONFLICT(key) DO UPDATE SET value=excluded.value
                """,
                (str(key), payload),
            )
            self._conn.commit()
            return str(key)

        def get(self, key: str, default: Any = MISSING) -> Any:
            """
            Get the value for a key.

            If the key does not exist:
                - If default is MISSING, raises KeyError.
                - Otherwise returns default.

            In async code, returns a coroutine you must `await`.
            In sync code, returns the value directly.
            """
            if in_async():
                async def _aget() -> Any:
                    async with aiosqlite.connect(self.sqlite_path) as db:
                        db.row_factory = sqlite3.Row
                        cursor = await db.execute(
                            f"SELECT value FROM {self.table_name} WHERE key = ?",
                            (str(key),),
                        )
                        row = await cursor.fetchone()
                        await cursor.close()

                        if row is None:
                            if default is MISSING:
                                raise KeyError(key)
                            return default

                        return self._loads(row["value"])

                return _aget()

            cursor = self._conn.execute(
                f"SELECT value FROM {self.table_name} WHERE key = ?",
                (str(key),),
            )
            row = cursor.fetchone()
            if row is None:
                if default is MISSING:
                    raise KeyError(key)
                return default
            return self._loads(row["value"])

        def remove(self, key: str) -> bool:
            """
            Remove a key-value pair.

            Returns True if a row was deleted, False otherwise.

            In async code, returns a coroutine you must `await`.
            In sync code, returns a bool directly.
            """
            if in_async():
                async def _aremove() -> bool:
                    async with aiosqlite.connect(self.sqlite_path) as db:
                        cursor = await db.execute(
                            f"DELETE FROM {self.table_name} WHERE key = ?",
                            (str(key),),
                        )
                        await db.commit()
                        return cursor.rowcount > 0

                return _aremove()

            cursor = self._conn.execute(
                f"DELETE FROM {self.table_name} WHERE key = ?",
                (str(key),),
            )
            self._conn.commit()
            return cursor.rowcount > 0

        def all(self) -> list[str]:
            """
            Return a list of all keys in the database.

            In async code, returns a coroutine you must `await`.
            In sync code, returns the list directly.
            """
            if in_async():
                async def _aall() -> list[str]:
                    async with aiosqlite.connect(self.sqlite_path) as db:
                        db.row_factory = sqlite3.Row
                        cursor = await db.execute(
                            f"SELECT key FROM {self.table_name} ORDER BY key"
                        )
                        rows = await cursor.fetchall()
                        await cursor.close()
                        return [row["key"] for row in rows]

                return _aall()

            cursor = self._conn.execute(
                f"SELECT key FROM {self.table_name} ORDER BY key"
            )
            return [row["key"] for row in cursor.fetchall()]

        def purge(self) -> bool:
            """
            Remove all key-value pairs from the database.

            Always returns True.

            In async code, returns a coroutine you must `await`.
            In sync code, returns True directly.
            """
            if in_async():
                async def _apurge() -> bool:
                    async with aiosqlite.connect(self.sqlite_path) as db:
                        await db.execute(f"DELETE FROM {self.table_name}")
                        await db.commit()
                        return True

                return _apurge()

            self._conn.execute(f"DELETE FROM {self.table_name}")
            self._conn.commit()
            return True

        def close(self) -> None:
            """
            Close the underlying sync SQLite connection.

            In async code, returns a coroutine you must `await`.
            In sync code, closes immediately.
            """
            if in_async():
                async def _aclose() -> None:
                    self._conn.close()
                return _aclose()

            self._conn.close()


else:
    class PickleDBSQLite:
        """
        This class is only usable if `aiosqlite` is installed, e.g.:
            pip install "pickledb[sqlite]"
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise RuntimeError(
                "PickleDBSQLite requires `aiosqlite`. "
                "Install it via `pip install \"pickledb[sqlite]\"`."
            )

