#!/usr/bin/env -S uv run --quiet --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "requests",
# ]
# ///
"""
klogs - Quick OpenTelemetry-aware log viewer for Kubernetes
Queries Loki API for flexible log searching
"""

import sys
import json
import time
from datetime import datetime, timedelta
from urllib.parse import urlencode
import argparse
import requests

# Configuration
# Base URL of the Loki HTTP API; the request paths (/loki/api/v1/...) are
# appended to it. Port 3100 matches the local side of the port-forward hint
# printed when the connection fails.
LOKI_URL = "http://localhost:3100"
# Default for the --namespace option of every sub-command.
DEFAULT_NAMESPACE = "learning"
# Default for the --pod option; used as a name prefix (pod=~"<value>.*").
DEFAULT_POD = "resources-graphql"
# kubectl context shown in the port-forward hint when Loki is unreachable.
DEFAULT_CONTEXT = "dev"

# Color codes for terminal
class Colors:
    """ANSI escape sequences used to style terminal output.

    Each attribute is a raw escape string to be written before the text it
    styles; ``END`` must follow the text to reset the terminal state.
    """
    # Bright foreground colors (SGR codes 91-96)
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    MAGENTA = '\033[95m'
    CYAN = '\033[96m'
    # Text attributes
    BOLD = '\033[1m'
    DIM = '\033[2m'
    # Reset every color/attribute set by the codes above
    END = '\033[0m'

def colorize(text, color):
    """Return *text* prefixed with the *color* escape code and followed by the reset code."""
    return "{}{}{}".format(color, text, Colors.END)

def parse_time_range(time_str):
    """Parse a relative time range such as '5m', '1h' or '2d' into a timedelta.

    The string is a positive integer followed by a one-letter unit:
    ``s`` (seconds), ``m`` (minutes), ``h`` (hours), ``d`` (days) or
    ``w`` (weeks). Leading and trailing whitespace is ignored.

    Args:
        time_str: The range string; may be empty or None.

    Returns:
        The corresponding ``timedelta``. Any input that cannot be used as a
        look-back window (empty, unknown unit, non-integer value, zero or
        negative value, value too large for ``timedelta``) falls back to
        one hour.
    """
    default = timedelta(hours=1)
    if not time_str:
        return default

    # Map each unit suffix to the timedelta keyword it stands for.
    unit_keywords = {
        's': 'seconds',
        'm': 'minutes',
        'h': 'hours',
        'd': 'days',
        'w': 'weeks',
    }

    text = str(time_str).strip()
    keyword = unit_keywords.get(text[-1:])
    if keyword is None:
        return default

    try:
        value = int(text[:-1])
    except ValueError:
        return default

    # A zero or negative window would put the query start at or after its
    # end, which is not a usable range.
    if value <= 0:
        return default

    try:
        return timedelta(**{keyword: value})
    except OverflowError:
        return default

def _escape_logql(value):
    """Escape *value* for embedding inside a double-quoted LogQL string."""
    # Backslashes first, so the ones added for quotes are not doubled again.
    return str(value).replace('\\', '\\\\').replace('"', '\\"')


def build_logql_query(namespace, pod, severity=None, trace_id=None, message=None, user_id=None, caller=None):
    """Build a LogQL query string for the given stream and optional filters.

    The stream selector matches ``namespace`` exactly and ``pod`` as a name
    prefix. When at least one filter is supplied, a ``| json`` stage is added
    followed by one label filter per option, in this order: severity, trace
    ID, message, user ID, caller.

    Args:
        namespace: Kubernetes namespace to select.
        pod: Pod name prefix (matched as the regex ``<pod>.*``).
        severity: Exact value of the ``severity`` field.
        trace_id: Exact value of the ``traceID`` field.
        message: Regex fragment searched case-insensitively in ``message``.
        user_id: Value matched against ``userID`` or ``authorizedUserID``.
        caller: Regex fragment searched in ``caller``.

    Returns:
        The LogQL query as a string. Double quotes and backslashes in every
        supplied value are escaped so they cannot terminate or corrupt the
        quoted string literals of the query.
    """
    # Base query
    query = f'{{namespace="{_escape_logql(namespace)}", pod=~"{_escape_logql(pod)}.*"}}'

    # Add JSON parsing and filters
    filters = []
    if severity:
        filters.append(f'severity="{_escape_logql(severity)}"')
    if trace_id:
        filters.append(f'traceID="{_escape_logql(trace_id)}"')
    if message:
        # Case-insensitive search
        filters.append(f'message=~"(?i).*{_escape_logql(message)}.*"')
    if user_id:
        user = _escape_logql(user_id)
        filters.append(f'(userID="{user}" or authorizedUserID="{user}")')
    if caller:
        filters.append(f'caller=~".*{_escape_logql(caller)}.*"')

    if filters:
        query += ' | json | ' + ' | '.join(filters)

    return query

def format_log_entry(log_data, show_json=False, compact=False):
    """Format a single log entry and print it to stdout.

    Args:
        log_data: The log line, either as a JSON string or as an already
            decoded dict. Anything that is not (or does not decode to) a
            JSON object is printed unchanged.
        show_json: When true, also print the full entry as indented JSON.
        compact: When true, print one line per entry (message truncated to
            100 characters) instead of the multi-line layout.

    Returns:
        None. All output goes to stdout.
    """
    try:
        log = json.loads(log_data) if isinstance(log_data, str) else log_data
    except (json.JSONDecodeError, TypeError):
        print(log_data)
        return

    # A line such as `123`, `null` or `[1, 2]` is valid JSON but has no
    # fields to read; show it as-is rather than failing on .get().
    if not isinstance(log, dict):
        print(log_data)
        return

    def field(key, default=''):
        # Fields may be null or non-string in the JSON; always yield a str so
        # slicing and width formatting below cannot fail.
        value = log.get(key, default)
        return default if value is None else str(value)

    severity = field('severity', 'INFO')
    timestamp = field('time')
    message = field('message')
    caller = field('caller')
    trace_id = field('traceID')
    user_id = field('userID') or field('authorizedUserID')

    # Format timestamp; fall back to the first 8 characters when the value is
    # not ISO-8601.
    try:
        dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
        time_str = dt.strftime('%H:%M:%S')
    except ValueError:
        time_str = timestamp[:8]

    # Color based on severity (looked up case-insensitively)
    severity_colors = {
        'ERROR': Colors.RED,
        'WARN': Colors.YELLOW,
        'WARNING': Colors.YELLOW,
        'INFO': Colors.GREEN,
        'DEBUG': Colors.BLUE,
    }
    color = severity_colors.get(severity.upper(), '')

    if compact:
        # Single line format
        parts = [
            colorize(f"{severity:5}", color),
            colorize(time_str, Colors.DIM),
            message[:100]
        ]
        print(' '.join(parts))
    else:
        # Multi-line format
        print(colorize(f"{severity:5}", color) + f" [{time_str}] {message}")

        if caller:
            print(f"       {colorize('↪', Colors.DIM)} {colorize(caller, Colors.DIM)}")

        if trace_id:
            print(f"       {colorize('🔗', Colors.CYAN)} Trace: {colorize(trace_id, Colors.CYAN)}")

        if user_id:
            print(f"       {colorize('👤', Colors.BLUE)} User: {colorize(user_id, Colors.BLUE)}")

    if show_json:
        print(colorize(json.dumps(log, indent=2), Colors.DIM))

    if not compact:
        print()  # Blank line between entries

def query_loki(query, time_range='1h', limit=50, direction='backward'):
    """Run a LogQL range query against Loki and return the decoded response.

    Args:
        query: The LogQL query string.
        time_range: Look-back window ending now, in the format accepted by
            ``parse_time_range`` (e.g. '5m', '1h').
        limit: Maximum number of entries requested from Loki.
        direction: Value of Loki's ``direction`` parameter.

    Returns:
        The JSON body of the ``/loki/api/v1/query_range`` response, decoded
        into Python objects.

    Exits the process with status 1 (after printing a message) when Loki
    cannot be reached, answers with an HTTP error, or the request otherwise
    fails.
    """
    # Work in integer nanoseconds; multiplying a float epoch by 1e9 exceeds
    # float precision and rounds the bounds.
    end_ns = time.time_ns()
    window_ns = (parse_time_range(time_range) // timedelta(microseconds=1)) * 1000
    start_ns = end_ns - window_ns

    params = {
        'query': query,
        'start': start_ns,
        'end': end_ns,
        'limit': limit,
        'direction': direction
    }

    try:
        url = f"{LOKI_URL}/loki/api/v1/query_range"
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.ConnectionError:
        print(colorize("❌ Cannot connect to Loki. Is port-forward running?", Colors.RED))
        print(colorize("\nRun this in another terminal:", Colors.YELLOW))
        print(colorize(f"  kubectl port-forward -n observability svc/loki-gateway 3100:80 --context={DEFAULT_CONTEXT}", Colors.BOLD))
        sys.exit(1)
    except requests.exceptions.HTTPError as e:
        print(colorize(f"❌ Error querying Loki: {e}", Colors.RED))
        # The status line alone rarely says what is wrong; the response body
        # carries Loki's own explanation (e.g. a query parse error).
        detail = e.response.text.strip() if e.response is not None else ''
        if detail:
            print(colorize(f"   {detail[:500]}", Colors.DIM))
        sys.exit(1)
    except requests.exceptions.RequestException as e:
        print(colorize(f"❌ Error querying Loki: {e}", Colors.RED))
        sys.exit(1)

def search_logs(args):
    """Run a filtered log search and print every matching entry.

    The query is assembled from the namespace, pod, severity, trace, message,
    user and caller options on *args*. Results are ordered by timestamp
    (newest first when ``args.order`` is 'newest') and at most ``args.limit``
    entries are printed.
    """
    logql = build_logql_query(
        args.namespace,
        args.pod,
        severity=args.severity,
        trace_id=args.trace,
        message=args.message,
        user_id=args.user,
        caller=args.caller,
    )

    # Echo what is about to be searched
    print(colorize("🔍 Searching logs", Colors.BOLD))
    print(colorize(f"   Query: {logql}", Colors.DIM))
    print(colorize(f"   Range: Last {args.time}", Colors.DIM))
    print()

    response = query_loki(logql, time_range=args.time, limit=args.limit)

    # Flatten the (timestamp, line) pairs of every returned stream
    streams = response.get('data', {}).get('result', [])
    entries = [
        (stamp, line)
        for stream in streams
        for stamp, line in stream.get('values', [])
    ]

    if not entries:
        print(colorize("No logs found", Colors.YELLOW))
        return

    # Order by timestamp in the requested direction
    newest_first = args.order == 'newest'
    entries.sort(key=lambda pair: pair[0], reverse=newest_first)

    print(colorize(f"✓ Found {len(entries)} log(s)\n", Colors.GREEN))

    shown = 0
    for _, line in entries:
        # Stop once the requested number of entries has been printed
        if args.limit and shown >= args.limit:
            break
        format_log_entry(line, show_json=args.json, compact=args.compact)
        shown += 1

def trace_logs(trace_id, args):
    """Print, in chronological order, every log line carrying *trace_id*.

    Searches the whole namespace given by ``args.namespace`` over the last
    two hours, requesting up to 1000 entries.
    """
    logql = '{{namespace="{}"}} | json | traceID="{}"'.format(args.namespace, trace_id)

    print(colorize(f"🔍 Trace: {trace_id}", Colors.CYAN))
    print()

    response = query_loki(logql, time_range='2h', limit=1000)

    # Flatten the (timestamp, line) pairs of every returned stream
    streams = response.get('data', {}).get('result', [])
    entries = [
        (stamp, line)
        for stream in streams
        for stamp, line in stream.get('values', [])
    ]

    if not entries:
        print(colorize("No logs found for this trace", Colors.YELLOW))
        return

    # Oldest first, so the trace reads top to bottom
    entries = sorted(entries, key=lambda pair: pair[0])

    print(colorize(f"✓ Found {len(entries)} log(s) in this trace\n", Colors.GREEN))

    for _, line in entries:
        format_log_entry(line, show_json=args.json)

def errors_only(args):
    """Print recent ERROR-severity logs, newest first.

    Uses ``args.namespace`` and ``args.pod`` to select the stream,
    ``args.time`` as the look-back window and ``args.limit`` as the maximum
    number of entries requested. ``args.json`` and ``args.compact`` control
    the output format of each entry.
    """
    # Build the query through the shared helper so this command stays in step
    # with `search --severity ERROR`.
    query = build_logql_query(
        namespace=args.namespace,
        pod=args.pod,
        severity='ERROR'
    )

    print(colorize("❌ Recent ERROR logs\n", Colors.RED))

    data = query_loki(query, time_range=args.time, limit=args.limit)

    # Flatten the (timestamp, line) pairs of every returned stream
    logs = [
        (timestamp, log_line)
        for stream in data.get('data', {}).get('result', [])
        for timestamp, log_line in stream.get('values', [])
    ]

    if not logs:
        print(colorize("✓ No errors found!", Colors.GREEN))
        return

    # Timestamps arrive as nanosecond strings; compare them numerically so
    # ordering does not depend on every string having the same length.
    logs.sort(key=lambda entry: int(entry[0]), reverse=True)

    print(colorize(f"Found {len(logs)} error(s)\n", Colors.RED))

    for _, log_line in logs:
        format_log_entry(log_line, show_json=args.json, compact=args.compact)

def tail_logs(args):
    """Follow new log lines by polling Loki every two seconds until Ctrl+C.

    The stream and filters come from ``args.namespace``, ``args.pod``,
    ``args.severity`` and ``args.message``. Each poll runs a range query from
    the newest timestamp already shown up to now and prints, in compact
    format, only the entries newer than that timestamp.

    A failed poll (connection problem, timeout, HTTP error, undecodable
    body) prints a warning and polling continues; the same kind of failure
    is reported only once until a poll succeeds again.
    """
    print(colorize(f"📡 Tailing logs from {args.pod} (Ctrl+C to stop)...\n", Colors.YELLOW))

    query = build_logql_query(
        namespace=args.namespace,
        pod=args.pod,
        severity=args.severity,
        message=args.message
    )
    url = f"{LOKI_URL}/loki/api/v1/query_range"

    try:
        # Note: Loki tail endpoint requires WebSocket, so we'll poll instead
        print(colorize("(Using polling mode - checking every 2 seconds)", Colors.DIM))
        print()

        last_timestamp = time.time_ns()
        last_error = None

        while True:
            # Query for new logs since last check
            params = {
                'query': query,
                'start': last_timestamp,
                'end': time.time_ns(),
                'limit': 100
            }

            try:
                # The timeout keeps a stalled connection from freezing the tail.
                response = requests.get(url, params=params, timeout=10)
                response.raise_for_status()
                data = response.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                # Identify the failure by type and status code rather than by
                # message: the message embeds the URL, whose timestamps
                # change on every poll.
                status = getattr(getattr(e, 'response', None), 'status_code', None)
                error_key = (type(e).__name__, status)
                if error_key != last_error:
                    print(colorize(f"⚠ Poll failed, will keep retrying: {e}", Colors.YELLOW))
                    last_error = error_key
            else:
                last_error = None
                logs = [
                    (timestamp, log_line)
                    for stream in data.get('data', {}).get('result', [])
                    for timestamp, log_line in stream.get('values', [])
                    if int(timestamp) > last_timestamp
                ]

                # Sort and display oldest first
                logs.sort(key=lambda entry: int(entry[0]))
                for timestamp, log_line in logs:
                    format_log_entry(log_line, compact=True)
                    last_timestamp = max(last_timestamp, int(timestamp))

            time.sleep(2)

    except KeyboardInterrupt:
        print(colorize("\n\nStopped tailing", Colors.YELLOW))

def main():
    """Parse the command line and run the selected klogs sub-command.

    Sub-commands: ``search``, ``trace``, ``errors`` and ``tail``. When no
    sub-command is given, the help text is printed and the process exits
    with status 1.
    """
    usage_examples = """
Examples:
  klogs search --severity ERROR --time 30m       # Errors in last 30 minutes
  klogs search --message "email" --limit 20      # Search for "email" 
  klogs search --user abc123                     # Logs for specific user
  klogs trace bb1af7a14944e6a41f1d2697522e822b  # All logs in a trace
  klogs errors                                   # Recent errors
  klogs tail --severity ERROR                    # Live error stream
        """

    # Option groups shared by several sub-commands. Each helper adds its
    # options in a fixed order so the generated help keeps the same layout.
    def add_severity_and_message(sub):
        sub.add_argument('--severity', choices=['DEBUG', 'INFO', 'WARN', 'ERROR'], help='Filter by severity')
        sub.add_argument('--message', '-m', help='Search in message text')

    def add_json(sub):
        sub.add_argument('--json', '-j', action='store_true', help='Show full JSON')

    def add_range_and_output(sub, time_help):
        sub.add_argument('--time', default='1h', help=time_help)
        sub.add_argument('--limit', '-l', type=int, default=50, help='Max results')
        add_json(sub)
        sub.add_argument('--compact', action='store_true', help='Compact output')

    def add_namespace(sub):
        sub.add_argument('--namespace', default=DEFAULT_NAMESPACE, help='Kubernetes namespace')

    def add_pod(sub):
        sub.add_argument('--pod', default=DEFAULT_POD, help='Pod name prefix')

    parser = argparse.ArgumentParser(
        description='klogs - Quick log viewer for Kubernetes + Loki',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    commands = parser.add_subparsers(dest='command', help='Command to run')

    # search
    search = commands.add_parser('search', help='Search logs')
    add_severity_and_message(search)
    search.add_argument('--trace', '-t', help='Filter by trace ID')
    search.add_argument('--user', '-u', help='Filter by user ID')
    search.add_argument('--caller', '-c', help='Filter by caller (file:line)')
    add_range_and_output(search, 'Time range (e.g., 5m, 1h, 24h)')
    search.add_argument('--order', choices=['newest', 'oldest'], default='oldest', help='Sort order')
    add_namespace(search)
    add_pod(search)

    # trace
    trace = commands.add_parser('trace', help='Show all logs for a trace ID')
    trace.add_argument('trace_id', help='OpenTelemetry trace ID')
    add_json(trace)
    add_namespace(trace)

    # errors
    errors = commands.add_parser('errors', help='Show recent ERROR logs')
    add_range_and_output(errors, 'Time range')
    add_namespace(errors)
    add_pod(errors)

    # tail
    tail = commands.add_parser('tail', help='Tail logs in real-time')
    add_severity_and_message(tail)
    add_namespace(tail)
    add_pod(tail)

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Map each sub-command name to the function that handles it
    handlers = {
        'search': search_logs,
        'trace': lambda parsed: trace_logs(parsed.trace_id, parsed),
        'errors': errors_only,
        'tail': tail_logs,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
