#!/usr/bin/env python3
"""
Show ACTUAL failed LLM data - the real JSON, bad LaTeX, raw output
No bullshit reports - just the data that broke
"""

import subprocess
import sys
import json
import re
import signal

# Restore the default SIGPIPE action so that piping into head/less ends this
# script quietly instead of raising BrokenPipeError once the reader closes.
# SIGPIPE is not defined on every platform (e.g. Windows), so guard the lookup
# to keep the script runnable there.
if hasattr(signal, "SIGPIPE"):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

import os

# The klogs executable is expected to sit next to this script, so resolve it
# relative to this file rather than the caller's working directory.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
KLOGS_PATH = os.path.join(SCRIPT_DIR, "klogs")

def _extract_json_objects(text):
    """Return every JSON object that starts at the beginning of a line in *text*.

    ANSI escape sequences are stripped first. Each candidate ``{`` at the
    start of a line is handed to ``json.JSONDecoder.raw_decode``, which finds
    the true end of the object itself. This handles pretty-printed objects
    with nested objects (whose inner closing braces sit alone on a line) as
    well as objects written on a single line. Candidates that do not decode
    are skipped.
    """
    # Strip ANSI color codes from output
    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
    clean_output = ansi_escape.sub('', text)

    object_start = re.compile(r'^\{', re.MULTILINE)
    decoder = json.JSONDecoder()
    logs = []
    pos = 0
    while True:
        match = object_start.search(clean_output, pos)
        if match is None:
            break
        try:
            log, pos = decoder.raw_decode(clean_output, match.start())
        except json.JSONDecodeError:
            # Not a JSON object after all (e.g. human-readable text that
            # happens to start with "{"); resume after this brace.
            pos = match.start() + 1
            continue
        logs.append(log)
    return logs


def run_klogs_json(args, klogs_path=None):
    """Run klogs with ``--json`` appended and return the parsed log entries.

    Args:
        args: Iterable of klogs arguments (sub-command and options), without
            the trailing ``--json``.
        klogs_path: Executable to run. Defaults to the module-level
            ``KLOGS_PATH`` when omitted.

    Returns:
        A list with one dict per JSON object found in klogs' stdout
        (non-JSON text around the objects is ignored). An empty list is
        returned when the command times out after 30 seconds, exits
        non-zero, or cannot be started; in those cases a diagnostic is
        printed to stderr instead of raising.
    """
    executable = KLOGS_PATH if klogs_path is None else klogs_path
    cmd = [executable, *args, "--json"]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=30)
    except subprocess.TimeoutExpired:
        print("⚠️  Query timed out (too much data). Try shorter time range like --time 1h", file=sys.stderr)
        return []
    except subprocess.CalledProcessError as e:
        print(f"Error running klogs: {e.stderr}", file=sys.stderr)
        print("Make sure port-forward is running:", file=sys.stderr)
        print("  kubectl port-forward -n observability svc/loki-gateway 3100:80 --context=dev", file=sys.stderr)
        return []
    except Exception as e:
        # Best-effort CLI boundary: e.g. the executable is missing or its
        # output cannot be decoded. Report and carry on with no results.
        print(f"Error: {e}", file=sys.stderr)
        return []

    return _extract_json_objects(result.stdout)


def show_raw_unmarshal_failures(time_range="30m"):
    """Print the raw payload, error and IDs of recent unmarshal failures.

    Queries klogs for up to 10 ERROR entries whose message matches
    "failed to unmarshal" within *time_range*. For each entry this prints
    the first raw-payload field present (``rawJSON``, ``raw_json`` or
    ``rawOutput``), the ``error`` field, and ``nodeID`` / ``lessonPlanID`` /
    ``traceID`` when the entry carries them.

    Args:
        time_range: Look-back window passed to klogs via ``--time``.
    """
    rule = "=" * 100
    print(rule)
    print("🔥 ACTUAL JSON THAT FAILED TO UNMARSHAL")
    print(rule)
    print()

    # Get error logs with rawJSON field
    logs = run_klogs_json(["search", "--severity", "ERROR", "--message", "failed to unmarshal", "--time", time_range, "--limit", "10"])

    if not logs:
        print("No unmarshal failures found")
        return

    for i, log in enumerate(logs, 1):
        print(f"\n{rule}")
        print(f"FAILURE #{i}")
        print(rule)

        # The payload may be logged under any of these keys; show the first
        # one that is present.
        for key in ('rawJSON', 'raw_json', 'rawOutput'):
            if key not in log:
                continue
            print("\n📄 RAW JSON THAT FAILED:")
            print("-" * 100)
            raw = log[key]
            try:
                # Strings are parsed so valid JSON can be pretty-printed;
                # values that are already structured are dumped as they are.
                parsed = json.loads(raw) if isinstance(raw, str) else raw
                print(json.dumps(parsed, indent=2))
            except (TypeError, ValueError):
                # Broken JSON is the whole point of this report: show it raw.
                print(raw)
            break

        # Show the error
        if 'error' in log:
            print("\n❌ ERROR:")
            print("-" * 100)
            print(log['error'])

        # Show context
        if 'nodeID' in log:
            print(f"\n📍 NODE ID: {log['nodeID']}")
        if 'lessonPlanID' in log:
            print(f"📍 LESSON PLAN ID: {log['lessonPlanID']}")
        if 'traceID' in log:
            print(f"🔗 TRACE ID: {log['traceID']}")

        print()


def _corrected_field(log):
    """Return ``log['error']['context']['correctedField']``, or None.

    None is returned when any step of that path is missing or is not a dict
    (for example ``error`` logged as a plain string, or ``context`` as null).
    """
    error_data = log.get('error')
    if not isinstance(error_data, dict):
        return None
    context = error_data.get('context')
    if not isinstance(context, dict):
        return None
    return context.get('correctedField')


def show_raw_latex_failures(time_range="30m"):
    """Print LaTeX correction failures, then the LaTeX errors detected.

    Two klogs searches are run, each limited to 5 entries within
    *time_range*:

    1. message "corrected field still has malformed latex": prints the first
       200 characters of the entry's corrected field and its trace ID.
    2. message "latex error before": prints the full corrected field plus
       trace and node IDs. When this search returns nothing the function
       returns right after reporting that.

    Args:
        time_range: Look-back window passed to klogs via ``--time``.
    """
    print("\n" + "=" * 100)
    print("📐 LATEX ERRORS AND CORRECTIONS")
    print("=" * 100)
    print()

    # First show correction failures (critical)
    print("🔴 CORRECTION FAILURES (LaTeX still broken after correction):")
    print("-" * 100)
    failure_logs = run_klogs_json(["search", "--message", "corrected field still has malformed latex", "--time", time_range, "--limit", "5"])

    if failure_logs:
        for i, log in enumerate(failure_logs, 1):
            print(f"\nFAILURE #{i}")
            field = _corrected_field(log)
            if field is not None:
                # str() keeps the preview working if the field is not text.
                print(f"Bad LaTeX: {str(field)[:200]}")
            if 'traceID' in log:
                print(f"Trace: {log['traceID']}")
    else:
        print("✅ No correction failures (all corrections succeeded!)")

    # Show initial errors detected
    print("\n\n🟡 INITIAL LATEX ERRORS DETECTED:")
    print("-" * 100)
    error_logs = run_klogs_json(["search", "--message", "latex error before", "--time", time_range, "--limit", "5"])

    if not error_logs:
        print("✅ No LaTeX errors detected")
        return

    print(f"Found {len(error_logs)} recent LaTeX errors that were corrected")

    for i, log in enumerate(error_logs, 1):
        print(f"\n{'='*100}")
        print(f"LATEX FAILURE #{i}")
        print(f"{'='*100}")

        # Try to extract the corrected field from error context
        field = _corrected_field(log)
        if field is not None:
            print("\n🔴 BAD LATEX (even after correction attempt):")
            print("-" * 100)
            print(field)

        # Show trace info
        if 'traceID' in log:
            print(f"\n🔗 TRACE ID: {log['traceID']}")
        if 'nodeID' in log:
            print(f"📍 NODE ID: {log['nodeID']}")

        print()


def _clip_llm_text(text, limit):
    """Return *text* cut to *limit* characters, marked when anything was cut."""
    if len(text) > limit:
        return text[:limit] + "...(truncated)"
    return text


def show_llm_prompt_responses(time_range="30m", limit=5):
    """Print metadata, prompts and output for recent LLM calls.

    Queries klogs for entries whose message matches
    "prompt-input-and-output". For each entry this prints model, provider,
    node type and prompt type ("N/A" when absent), the trace ID, every
    prompt field present (``system``, ``user``, ``prompt``; each cut to 500
    characters) and the first output field present (``output``,
    ``rawOutput`` or ``response``). Output that is valid JSON is
    pretty-printed and cut to 1000 characters; anything else is cut to 500.
    Cut text always ends in "...(truncated)".

    Args:
        time_range: Look-back window passed to klogs via ``--time``.
        limit: Maximum number of entries requested from klogs.
    """
    print("\n" + "=" * 100)
    print(f"💬 ACTUAL LLM PROMPTS AND RESPONSES (last {limit})")
    print("=" * 100)
    print()

    # Get prompt-input-and-output logs
    logs = run_klogs_json(["search", "--message", "prompt-input-and-output", "--time", time_range, "--limit", str(limit)])

    if not logs:
        print("No LLM operation logs found")
        return

    for i, log in enumerate(logs, 1):
        print(f"\n{'='*100}")
        print(f"LLM CALL #{i}")
        print(f"{'='*100}")

        # Show what we can extract
        print(f"📍 Model: {log.get('model', 'N/A')}")
        print(f"📍 Provider: {log.get('provider', 'N/A')}")
        print(f"📍 Node Type: {log.get('nodeType', 'N/A')}")
        print(f"📍 Prompt Type: {log.get('promptType', 'N/A')}")

        if 'traceID' in log:
            print(f"🔗 Trace ID: {log['traceID']}")

        # Try to show prompts if available (all that are present, not just
        # the first).
        for key in ('system', 'user', 'prompt'):
            if key in log:
                print(f"\n🎤 {key.upper()} PROMPT:")
                print("-" * 100)
                # str() so a non-text value cannot break len()/slicing.
                print(_clip_llm_text(str(log[key]), 500))

        # Show output if available (first matching field only)
        for key in ('output', 'rawOutput', 'response'):
            if key not in log:
                continue
            print("\n📤 LLM OUTPUT:")
            print("-" * 100)
            output = log[key]
            try:
                # Parse strings so valid JSON can be pretty-printed; values
                # that are already structured are dumped directly.
                parsed = json.loads(output) if isinstance(output, str) else output
                rendered, max_chars = json.dumps(parsed, indent=2), 1000
            except (TypeError, ValueError):
                rendered, max_chars = str(output), 500
            print(_clip_llm_text(rendered, max_chars))
            break

        print()


def show_specific_trace_data(trace_id):
    """Print every log entry klogs returns for *trace_id* (up to 100).

    For each entry this prints severity, timestamp and the first 200
    characters of the message, followed by a fixed set of fields of interest
    when present. Each field value is cut to 500 characters, and cut values
    end in "...(truncated)"; dict and list values are rendered as indented
    JSON before cutting.

    Args:
        trace_id: Trace identifier passed to ``klogs trace``.
    """
    # Fields worth showing, in display order.
    interesting = ('error', 'nodeID', 'lessonPlanID', 'model', 'provider',
                   'rawOutput', 'output', 'rawJSON', 'correctedField',
                   'attempt', 'backoff', 'promptType')
    max_chars = 500

    print("=" * 100)
    print(f"🔍 COMPLETE TRACE DATA: {trace_id}")
    print("=" * 100)
    print()

    # Get all logs for this trace with JSON
    logs = run_klogs_json(["trace", trace_id, "--limit", "100"])

    if not logs:
        print(f"No logs found for trace {trace_id}")
        return

    print(f"Found {len(logs)} log entries\n")

    for i, log in enumerate(logs, 1):
        severity = log.get('severity', 'UNKNOWN')
        message = log.get('message', '')
        timestamp = log.get('timestamp', '')

        print(f"\n{'='*100}")
        print(f"#{i} [{severity}] {timestamp}")
        print(f"{'='*100}")
        # str() so a null or structured message cannot break the slice.
        print(f"Message: {str(message)[:200]}")

        for key in interesting:
            if key not in log:
                continue
            value = log[key]
            print(f"\n📌 {key}:")
            print("-" * 80)
            if isinstance(value, (dict, list)):
                text = json.dumps(value, indent=2)
            else:
                text = str(value)
            if len(text) > max_chars:
                text = text[:max_chars] + "...(truncated)"
            print(text)

        print()


def get_latest_failures_with_data(time_range="30m"):
    """Dump the most recent ERROR-severity log entries in full.

    Requests up to five ERROR entries from klogs for *time_range*, prints
    each one as indented JSON and, for entries that carry a ``traceID``,
    prints the command that shows the whole trace.
    """
    banner = "=" * 100
    print(banner)
    print("🔥 LATEST FAILURES - COMPLETE DATA DUMP")
    print(banner)
    print()

    # Most recent errors only; the limit keeps the dump readable.
    query = ["search", "--severity", "ERROR", "--time", time_range, "--limit", "5"]
    entries = run_klogs_json(query)

    if not entries:
        print("No errors found (lucky you!)")
        return

    for number, entry in enumerate(entries, start=1):
        print("\n" + banner)
        print(f"ERROR #{number}")
        print(banner)

        # The whole entry, untouched apart from indentation.
        print(json.dumps(entry, indent=2))

        # Point at the trace command when the entry can be followed up.
        if 'traceID' in entry:
            trace_id = entry['traceID']
            print("\n💡 Want full trace? Run:")
            print(f"   ./llm-show-failures trace {trace_id}")

        print()


def print_usage():
    """Print the command-line help text to stdout.

    The synopsis lists the syntax main() parses: the time window is given
    with ``--time <range>`` (there is no positional time argument), and
    ``prompts`` accepts an optional numeric count as its first argument.
    """
    usage_lines = (
        "",
        "LLM Show Failures - See the ACTUAL data that failed",
        "",
        "Commands:",
        "  unmarshal [--time T]      - Show actual JSON that failed to unmarshal",
        "  latex [--time T]          - Show actual bad LaTeX that failed validation",
        "  prompts [N] [--time T]    - Show actual LLM prompts and responses (last N, default 5)",
        "  trace <trace-id>          - Show ALL data for a specific trace",
        "  dump [--time T]           - Dump latest failures with complete data",
        "",
        "Examples:",
        "  ./llm-show-failures unmarshal          # Last 30min",
        "  ./llm-show-failures latex --time 1h    # Last 1 hour",
        "  ./llm-show-failures prompts --time 5m  # Last 5 min",
        "  ./llm-show-failures prompts 10         # Last 10 LLM calls",
        "  ./llm-show-failures trace abc123...    # Specific trace",
        "  ./llm-show-failures dump               # Latest errors, full data",
        "",
        "Time formats: 5m, 30m, 1h, 2h, 24h",
        "",
        "This shows the REAL data - not summaries. The actual broken JSON, bad LaTeX,",
        "full prompts, raw outputs. Everything you need to debug.",
        "",
    )
    print("\n".join(usage_lines))


def main(argv=None):
    """CLI entry point: dispatch ``argv[1]`` to the matching report function.

    Args:
        argv: Full argument vector including the program name. Defaults to
            ``sys.argv`` when omitted.

    The look-back window comes from ``--time <range>`` anywhere on the
    command line (default "30m"). ``prompts`` additionally accepts a numeric
    count as the argument right after the command (default 5).

    Exits with status 1 when no command is given, the command is unknown,
    ``trace`` is missing its trace ID, or ``--time`` is missing its value.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) < 2:
        print_usage()
        sys.exit(1)

    command = argv[1]

    # Parse time option
    time_range = "30m"
    if "--time" in argv:
        idx = argv.index("--time")
        if idx + 1 >= len(argv):
            # Falling back to the default here would silently query a
            # different window than the user asked for.
            print("Error: --time requires a value (e.g. --time 1h)", file=sys.stderr)
            sys.exit(1)
        time_range = argv[idx + 1]

    if command == "unmarshal":
        show_raw_unmarshal_failures(time_range)
    elif command == "latex":
        show_raw_latex_failures(time_range)
    elif command == "prompts":
        limit = 5
        if len(argv) > 2 and argv[2].isdigit():
            limit = int(argv[2])
        show_llm_prompt_responses(time_range, limit)
    elif command == "trace":
        if len(argv) < 3:
            print("Error: trace command requires a trace ID", file=sys.stderr)
            sys.exit(1)
        show_specific_trace_data(argv[2])
    elif command == "dump":
        get_latest_failures_with_data(time_range)
    elif command in ("--help", "-h", "help"):
        print_usage()
    else:
        print(f"Unknown command: {command}", file=sys.stderr)
        print_usage()
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
