import os
from enum import Enum
from typing import List, Optional

import edge_tts
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel


# Load environment variables
load_dotenv()

# Import our custom services from the services package
from services.tts_service import tts_service
from services.transcription_service import transcription_service
from services.avatar_service import avatar_service
from services.infinite_video_generation_service import create_video_job, get_video_job

app = FastAPI(title="GLMA Optimized Media API")

# CORS configuration.
# Allowed origins are read from the optional CORS_ALLOW_ORIGINS environment
# variable (comma-separated, e.g. "https://a.example,https://b.example").
# When it is unset or blank, every origin is allowed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # FastAPI's CORS documentation states that allow_origins must not be
    # ["*"] when allow_credentials is True, so credentialed requests are only
    # enabled once explicit origins have been configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state
# Whisper model size name, overridable through WHISPER_MODEL_SIZE.
MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "small")
# Whisper model instance; stays None until the startup handler assigns it.
model = None

@app.on_event("startup")
async def startup_event():
    """Load the Whisper model into the module-level ``model`` at startup.

    Prints a banner around the loading step. If constructing the model
    raises, the error is printed and ``model`` is left unchanged, so the
    application still starts.
    """
    global model

    banner = "=" * 50
    print("\n" + banner)
    print(f"LOADING WHISPER MODEL ({MODEL_SIZE})... PLEASE WAIT")
    print(banner)

    # Optimization: 4 CPU threads and 2 workers for faster CPU transcription.
    whisper_options = {
        "device": "cpu",
        "compute_type": "int8",
        "cpu_threads": 4,
        "num_workers": 2,
    }
    try:
        model = WhisperModel(MODEL_SIZE, **whisper_options)
        print("SUCCESS: Whisper model loaded and ready!")
    except Exception as load_error:
        # Best effort: report the failure but let the app come up anyway.
        print(f"ERROR loading model: {load_error}")

    print(banner + "\n")

@app.get("/")
async def root():
    """Return a static payload indicating that the API is up."""
    payload = dict(status="online", message="API is running successfully.")
    return payload

@app.post("/speak/")
async def text_to_speech(
    background_tasks: BackgroundTasks,
    text: str = Form(...),
    voice: str = Form("Auto"),
    rate: str = Form("Auto"),
    pitch: str = Form("Auto"),
    volume: str = Form("Auto"),
    tashkeel: bool = Form(False),
):
    """Forward a text-to-speech request to ``tts_service``.

    All form fields are passed through unchanged, in declaration order,
    together with the request's ``BackgroundTasks`` instance. ``text`` is
    required; ``voice``, ``rate``, ``pitch`` and ``volume`` default to
    ``"Auto"`` and ``tashkeel`` defaults to ``False``.

    Returns:
        Whatever ``tts_service`` returns.
    """
    speech_options = (text, voice, rate, pitch, volume, tashkeel)
    response = await tts_service(background_tasks, *speech_options)
    return response

@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Hand the uploaded file and the global Whisper model to ``transcription_service``.

    Returns:
        Whatever ``transcription_service`` returns.
    """
    # NOTE(review): `model` is still None if loading failed at startup;
    # presumably transcription_service copes with that - confirm.
    transcription = await transcription_service(file, model)
    return transcription

@app.post("/generate_avatar/")
async def generate_avatar(
    background_tasks: BackgroundTasks,
    text: str = Form(...),
    voice: str = Form("Auto"),
    rate: str = Form("Auto"),
    pitch: str = Form("Auto"),
    volume: str = Form("Auto"),
    tashkeel: bool = Form(False),
    # Optional[...] so the annotation agrees with the None default.
    avatar_prompt: Optional[str] = Form(None),
):
    """Forward an avatar-generation request to ``avatar_service``.

    All form fields are passed through unchanged, in declaration order,
    together with the request's ``BackgroundTasks`` instance. ``text`` is
    required; ``voice``, ``rate``, ``pitch`` and ``volume`` default to
    ``"Auto"``, ``tashkeel`` defaults to ``False`` and ``avatar_prompt``
    is ``None`` when the client omits it.

    Returns:
        Whatever ``avatar_service`` returns.
    """
    return await avatar_service(
        background_tasks,
        text,
        voice,
        rate,
        pitch,
        volume,
        tashkeel,
        avatar_prompt,
    )

@app.get("/available_voices/")
async def get_available_voices():
    """Return the ``ShortName`` of every voice listed by ``edge_tts``, sorted."""
    voice_catalog = await edge_tts.list_voices()
    short_names = (entry["ShortName"] for entry in voice_catalog)
    return sorted(short_names)

# Example of a comma-separated emotion list (presumably a sample value for the
# `emotions` form field of /video/create/): confident,excited,engaging,motivated,calm
class Emotion(str, Enum):
    """Closed set of emotion labels.

    Every member's name is identical to its string value. Because the class
    mixes in ``str``, members compare equal to the corresponding plain
    strings.
    """
    happy = "happy"
    sad = "sad"
    neutral = "neutral"
    angry = "angry"
    surprised = "surprised"
    calm = "calm"
    excited = "excited"
    anxious = "anxious"
    confident = "confident"
    frustrated = "frustrated"
    relaxed = "relaxed"
    curious = "curious"
    bored = "bored"
    motivated = "motivated"
    disappointed = "disappointed"
    hopeful = "hopeful"
    tired = "tired"
    engaging = "engaging"

@app.post("/video/create/")
async def create_video(
    image: UploadFile = File(...),
    audio: UploadFile = File(...),
    emotions: str = Form(...),
    prompt: str = Form(...)
):
    """Validate the requested emotions and start a video job.

    Args:
        image: Uploaded image file, forwarded to ``create_video_job``.
        audio: Uploaded audio file, forwarded to ``create_video_job``.
        emotions: Comma-separated ``Emotion`` names. Matching is
            case-insensitive, surrounding whitespace is ignored and empty
            entries (e.g. from a trailing comma) are skipped.
        prompt: Free-text prompt, forwarded to ``create_video_job``.

    Returns:
        Whatever ``create_video_job`` returns.

    Raises:
        HTTPException: 400 if any entry is not an ``Emotion`` name, or if no
            emotion remains after skipping empty entries.
    """
    emotions_list = []
    for raw_entry in emotions.split(","):
        entry = raw_entry.strip()
        if not entry:
            # Tolerate stray separators instead of reporting an empty name.
            continue
        name = entry.lower()
        if name not in Emotion.__members__:
            raise HTTPException(status_code=400, detail=f"Invalid emotion: {entry}")
        emotions_list.append(Emotion[name])

    if not emotions_list:
        raise HTTPException(status_code=400, detail="At least one emotion is required")

    # Downstream receives the validated values joined as "a, b, c".
    emotions_str = ", ".join(emotion.value for emotion in emotions_list)
    return await create_video_job(image, audio, emotions_str, prompt)

@app.get("/video/status/{job_id}")
async def video_status(job_id: str):
    """Return what ``get_video_job`` reports for the given ``job_id``."""
    job_report = await get_video_job(job_id)
    return job_report


if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn. The bind address and port
    # come from APP_HOST / APP_PORT, defaulting to 0.0.0.0:8080.
    import uvicorn

    uvicorn.run(
        app,
        host=os.environ.get("APP_HOST", "0.0.0.0"),
        port=int(os.environ.get("APP_PORT", "8080")),
    )