import os
import io
import base64
import json
import uuid
import asyncio
import subprocess
import time
from datetime import datetime, timedelta
from typing import Optional, Dict, Any

import requests
from fastapi import UploadFile, HTTPException, BackgroundTasks
from PIL import Image
from mutagen.mp3 import MP3
import numpy as np
import librosa

# ==============================
# Configuration
# ==============================

# RunComfy credentials. Hard-coding secrets in source is a security risk:
# read them from the environment, keeping the previous literals as defaults
# so existing deployments keep working. TODO(review): drop the fallback
# defaults and rotate these credentials.
DEPLOYMENT_ID = os.getenv(
    "RUNCOMFY_DEPLOYMENT_ID", "17079ce3-e7b6-4e3e-8530-f720b5efba0d"
)
TOKEN = os.getenv("RUNCOMFY_TOKEN", "46889de3-7a9f-45bf-8310-af35019db8cc")
BASE_URL = os.getenv("RUNCOMFY_BASE_URL", "https://api.runcomfy.net/prod/v1")

# Headers sent on every RunComfy request.
HEADERS = {
    "Authorization": f"Bearer {TOKEN}",
    "Content-Type": "application/json",
}

# Frame rate used to convert audio duration into a workflow frame count.
FPS = 25

# Directory for uploaded temp files; created eagerly at import time.
TEMP_DIR = "temp_uploads"
os.makedirs(TEMP_DIR, exist_ok=True)


# MIME types accepted for the image upload.
ALLOWED_IMAGE_TYPES = {
    "image/jpeg",
    "image/png",
    "image/jpg",
}

# MIME types accepted for the audio upload.
ALLOWED_AUDIO_TYPES = {
    "audio/mpeg",
    "audio/mp3",
    "audio/wav",
    "audio/ogg",
}


# =====================================
# NODE IDs
# =====================================

# ComfyUI workflow node ids whose inputs are overridden per request.
NODE_IMAGE = "284"
NODE_AUDIO = "125"
NODE_AUDIO_CROP = "159"
NODE_FRAMES = "270"
NODE_PROMPT = "241"


# =====================================
# JOB STORE (in-memory)
# =====================================

# job_id -> job record; lost on process restart (single-worker assumption).
jobs: Dict[str, Dict[str, Any]] = {}


# =====================================
# UTILITIES
# =====================================

def seconds_to_mmss(seconds: float):
    """Format a duration in seconds as an ``M:SS`` string (seconds zero-padded)."""
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes}:{secs:02d}"


def get_audio_duration(file_path: str) -> float:
    """Return the duration of an audio file in seconds.

    MP3 files are first probed with mutagen (cheap, no subprocess); any
    parsing failure falls through to an ``ffprobe`` subprocess, which
    handles every container ffmpeg understands.

    Raises:
        RuntimeError: if ffprobe exits non-zero or its output has no
            usable duration.
    """
    if file_path.endswith(".mp3"):
        try:
            return float(MP3(file_path).info.length)
        except Exception:
            # Corrupt or unusual MP3 headers: fall back to ffprobe below.
            pass

    cmd = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        file_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)

    # Without this check json.loads("") raises a confusing JSONDecodeError
    # when ffprobe fails (e.g. file missing or unreadable).
    if result.returncode != 0:
        raise RuntimeError(
            f"ffprobe failed for {file_path}: {result.stderr.strip()}"
        )

    try:
        data = json.loads(result.stdout)
        return float(data["format"]["duration"])
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        raise RuntimeError(
            f"could not read duration from ffprobe output for {file_path}"
        ) from e


def encode_image(path: str) -> str:
    """Load an image, normalize it to RGB, cap it at 1024x1024, and return
    it as a base64 JPEG data URI.

    The original left the file handle from ``Image.open`` unclosed (a
    resource leak); the ``with`` block closes it deterministically.
    """
    with Image.open(path) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")

        # thumbnail() resizes in place, preserving aspect ratio; it never
        # upscales an image smaller than 1024x1024.
        img.thumbnail((1024, 1024))

        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=95)

    b64 = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{b64}"


def encode_audio(path: str, mime: str) -> str:
    """Base64-encode the file at *path* into a ``data:`` URI tagged *mime*."""
    with open(path, "rb") as audio_file:
        payload = audio_file.read()

    encoded = base64.b64encode(payload).decode()
    return f"data:{mime};base64,{encoded}"


def calculate_max_frames(duration: float) -> int:
    """Convert an audio duration (seconds) into a frame count at the
    workflow frame rate, never returning fewer than one frame."""
    frames = int(duration * FPS)
    return frames if frames >= 1 else 1


def detect_emotion(audio_path: str) -> str:
    """Heuristically classify the speaker's emotion from audio features.

    Uses mean RMS energy and mean voiced pitch (librosa piptrack):
      * high pitch (>200) and high energy (>0.05) -> "happy"
      * low pitch (<130) and low energy (<0.03)   -> "sad"
      * anything else, or any analysis failure     -> "neutral"
    """
    try:
        y, sr = librosa.load(audio_path, sr=None)

        energy = float(np.mean(librosa.feature.rms(y=y)))

        pitches, _ = librosa.piptrack(y=y, sr=sr)
        voiced = pitches[pitches > 0]

        # 150 Hz is a neutral default when no voiced frames are found.
        pitch = float(np.mean(voiced)) if len(voiced) else 150

        if pitch > 200 and energy > 0.05:
            return "happy"

        if pitch < 130 and energy < 0.03:
            return "sad"

        return "neutral"

    except Exception:
        # Narrowed from a bare `except:`, which also swallowed SystemExit
        # and KeyboardInterrupt. Best-effort heuristic: any decode or
        # analysis error degrades to "neutral".
        return "neutral"


# =====================================
# RUNCOMFY API
# =====================================

def set_keep_warm_duration(
    seconds: int = 5,
    min_instances: int = 0,
    max_instances: int = 1,
):
    """Patch the deployment's autoscaling so it stays warm for *seconds*
    after each request finishes.

    Best-effort and idempotent: failures are logged, never raised, so it
    is safe to call before every inference.
    """
    endpoint = f"{BASE_URL}/deployments/{DEPLOYMENT_ID}"

    body = {
        "autoscaling": {
            "min_instances": min_instances,
            "max_instances": max_instances,
            "keep_warm_duration_in_seconds": seconds,
        }
    }

    try:
        resp = requests.patch(endpoint, headers=HEADERS, json=body, timeout=20)

        if resp.status_code in (200, 204):
            print(
                f"Deployment autoscaling updated: "
                f"keep_warm_duration={seconds}s"
            )
        else:
            print("Warning: could not update keep warm:",
                  resp.status_code,
                  resp.text)

    except Exception as e:
        print("Warning: keep warm update failed:", str(e))


def start_inference(image_uri, audio_uri, duration, emotion, prompt):
    """Kick off a RunComfy inference run and return its status-poll URL.

    Overrides the workflow's image, audio, audio-crop window, frame count
    and positive prompt nodes for this request.

    Raises:
        RuntimeError: on a non-200 response or a missing status_url.
    """
    set_keep_warm_duration(seconds=5)

    overrides = {
        NODE_IMAGE: {"inputs": {"image": image_uri}},
        NODE_AUDIO: {"inputs": {"audio": audio_uri}},
        NODE_AUDIO_CROP: {
            "inputs": {
                # Crop the workflow's audio to exactly the clip length.
                "start_time": "0:00",
                "end_time": seconds_to_mmss(duration),
            }
        },
        NODE_FRAMES: {"inputs": {"value": calculate_max_frames(duration)}},
        NODE_PROMPT: {
            "inputs": {
                "positive_prompt":
                    f"{prompt}, {emotion}, ultra realistic, cinematic lighting"
            }
        },
    }

    response = requests.post(
        f"{BASE_URL}/deployments/{DEPLOYMENT_ID}/inference",
        headers=HEADERS,
        json={"overrides": overrides},
    )

    if response.status_code != 200:
        raise RuntimeError(response.text)

    status_url = response.json().get("status_url")
    if not status_url:
        raise RuntimeError("RunComfy did not return status_url")

    return status_url


async def poll_with_progress(status_url, job):
    """Poll a RunComfy status URL every 2 s, mirroring progress into *job*.

    Returns the final status payload once the run completes.

    Raises:
        RuntimeError: on a non-200 status response, a failed run, or after
            ~1 hour (1800 polls) without completion.
    """
    max_attempts = 1800  # ~1 hour at 2 s per poll

    for _ in range(max_attempts):

        # requests is blocking; run it off the event loop.
        response = await asyncio.to_thread(
            requests.get,
            status_url,
            headers=HEADERS
        )

        if response.status_code != 200:
            raise RuntimeError(response.text)

        data = response.json()

        status = data.get("status", "")
        progress = data.get("progress")

        # BUG FIX: check terminal states before the progress field. The
        # original elif chain took the `progress is not None` branch first,
        # so a payload carrying both a progress value and a terminal status
        # never returned and the loop polled until timeout.
        if status in ("completed", "succeeded"):
            job["progress"] = 1.0
            return data

        if status in ("failed", "error"):
            raise RuntimeError(data.get("error", "Generation failed"))

        if progress is not None:
            job["progress"] = float(progress)
        elif status == "queued":
            job["progress"] = 0.01
        elif status == "running":
            # Estimated fallback when the API reports no progress: creep
            # the bar toward 95% so the UI still moves.
            job["progress"] = min(0.95, job.get("progress", 0.01) + 0.01)

        await asyncio.sleep(2)

    raise RuntimeError("Timeout waiting for RunComfy")

async def process_job(job_id, image_path, audio_path, emotion, prompt):
    """Background pipeline for one job: encode the inputs, start a RunComfy
    inference, poll it to completion, fetch the result, and record the video
    URL on the in-memory job entry.

    Any failure marks the job as "failed" with the error message; the
    uploaded temp files are always removed afterwards.
    """
    job = jobs[job_id]
    try:
        job["status"] = "processing"
        job["progress"] = 0.0

        duration = get_audio_duration(audio_path)
        image_uri = encode_image(image_path)
        audio_uri = encode_audio(audio_path, "audio/mpeg")

        # start_inference is blocking (requests); keep it off the loop.
        status_url = await asyncio.to_thread(
            start_inference,
            image_uri,
            audio_uri,
            duration,
            emotion,
            prompt,
        )
        final_status = await poll_with_progress(status_url, job)

        request_id = final_status.get("request_id") or final_status.get("id")
        job["request_id"] = request_id
        if not request_id:
            raise RuntimeError("Could not retrieve request_id from final status")

        result_url = f"{BASE_URL}/deployments/{DEPLOYMENT_ID}/requests/{request_id}/result"
        resp = await asyncio.to_thread(
            requests.get,
            result_url,
            headers=HEADERS,
            timeout=30,
        )
        resp.raise_for_status()
        result_data = resp.json()

        video_url = extract_video(result_data) or result_data.get("result_url")
        if not video_url:
            raise RuntimeError("No video URL found in result")

        job["status"] = "completed"
        job["video_url"] = video_url
        job["progress"] = 1.0
        # Result links are treated as valid for 24 hours.
        job["expires_at"] = (datetime.utcnow() + timedelta(hours=24)).isoformat()

    except Exception as e:
        job["status"] = "failed"
        job["error"] = str(e)
    finally:
        # Clean up the uploaded temp files.
        for temp_path in (image_path, audio_path):
            if os.path.exists(temp_path):
                os.remove(temp_path)
                
def extract_video(result):
    """Pull the first playable media URL out of a RunComfy result payload.

    Scans every node's "gifs" and "videos" output lists for an entry with
    a non-empty "url"; falls back to the payload's top-level "result_url"
    (which may be None).
    """
    for node_output in result.get("outputs", {}).values():
        for media_key in ("gifs", "videos"):
            for entry in node_output.get(media_key, []):
                if entry.get("url"):
                    return entry["url"]
    return result.get("result_url")

# =====================================
# BACKGROUND JOB
# =====================================

# async def process_job(job_id, image_path, audio_path):

#     job = jobs[job_id]

#     try:

#         job["status"] = "processing"
#         job["progress"] = 0.0

#         duration = get_audio_duration(audio_path)
#         emotion = detect_emotion(audio_path)

#         image_uri = encode_image(image_path)
#         audio_uri = encode_audio(audio_path, "audio/mpeg")

#         status_url = await asyncio.to_thread(
#             start_inference,
#             image_uri,
#             audio_uri,
#             duration,
#             emotion
#         )

#         result = await poll_with_progress(status_url, job)

#         request_id = result.get("request_id") or result.get("id")

#         result_url = f"{BASE_URL}/deployments/{DEPLOYMENT_ID}/requests/{request_id}/result"

#         resp = await asyncio.to_thread(
#             requests.get,
#             result_url,
#             headers=HEADERS,
#             timeout=30
#         )

#         resp.raise_for_status()

#         result_data = resp.json()

#         video_url = extract_video(result_data)

#         if not video_url:
#             raise RuntimeError("No video returned")

#         job["status"] = "completed"
#         job["video_url"] = video_url
#         job["progress"] = 1.0

#         job["expires_at"] = (
#             datetime.utcnow() + timedelta(hours=24)
#         ).isoformat()

#     except Exception as e:

#         job["status"] = "failed"
#         job["error"] = str(e)

#     finally:

#         if os.path.exists(image_path):
#             os.remove(image_path)

#         if os.path.exists(audio_path):
#             os.remove(audio_path)


# =====================================
# PUBLIC API FUNCTIONS
# =====================================

async def create_video_job(
    image: UploadFile,
    audio: UploadFile,
    emotion: str,
    prompt: str
):
    """Validate the uploads, persist them to temp files, register a job
    record, and launch background processing.

    Returns the freshly-created job record (status "queued").

    Raises:
        HTTPException: 400 on an unsupported image or audio content type.
    """
    if image.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(400, "Invalid image type")
    if audio.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(400, "Invalid audio type")

    job_id = str(uuid.uuid4())

    # NOTE(review): extensions are fixed regardless of the actual upload
    # type; downstream code assumes .mp3 for the duration probe.
    image_path = f"{TEMP_DIR}/{job_id}.jpg"
    audio_path = f"{TEMP_DIR}/{job_id}.mp3"

    image_bytes = await image.read()
    with open(image_path, "wb") as out:
        out.write(image_bytes)

    audio_bytes = await audio.read()
    with open(audio_path, "wb") as out:
        out.write(audio_bytes)

    jobs[job_id] = {
        "request_id": None,
        "job_id": job_id,
        "status": "queued",
        "progress": 0.0,
        "video_url": None,
        "error": None,
    }

    # Fire-and-forget: the task updates jobs[job_id] as it runs.
    asyncio.create_task(
        process_job(job_id, image_path, audio_path, emotion, prompt)
    )

    return jobs[job_id]


async def get_video_job(job_id: str):
    """Return the current state of a job, refreshed from RunComfy when a
    request_id is already known.

    Raises:
        HTTPException: 404 if the job does not exist, 500 if the RunComfy
            status request fails.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    request_id = job.get("request_id")
    if not request_id:
        # Processing has not reached RunComfy yet: report local state only.
        return {
            "job_id": job_id,
            "status": job["status"],
            "progress": job["progress"],
            "video_url": job.get("video_url"),
            "error": job.get("error"),
            "request_id": None
        }

    # Ask RunComfy for the authoritative status of this request.
    status_url = f"{BASE_URL}/deployments/{DEPLOYMENT_ID}/requests/{request_id}/status"
    try:
        resp = requests.get(status_url, headers=HEADERS, timeout=30)
        resp.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Error fetching status: {e}")

    remote = resp.json()

    # Mirror the remote values into the in-memory job record, keeping the
    # existing value whenever the remote payload omits a field.
    for field, fallback in (
        ("status", job["status"]),
        ("progress", job["progress"]),
        ("error", job.get("error")),
    ):
        job[field] = remote.get(field, fallback)
    job["result_url"] = remote.get("result_url", job.get("video_url"))

    return {
        "job_id": job_id,
        "request_id": request_id,
        "status": job["status"],
        "progress": job["progress"],
        "video_url": job.get("video_url"),
        "error": job.get("error"),
        "result_url": job.get("result_url"),
    }

