import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint

def kuramoto_derivative(theta, t, omega, K, N):
    """
    Compute the right-hand side d(theta)/dt of the Kuramoto model.

        dtheta_i/dt = omega_i + (K/N) * sum_j sin(theta_j - theta_i)

    The pairwise sum is evaluated in O(N) instead of O(N^2). Expanding
    sin(theta_j - theta_i) = sin(theta_j)cos(theta_i) - cos(theta_j)sin(theta_i)
    gives

        sum_j sin(theta_j - theta_i) = S*cos(theta_i) - C*sin(theta_i),

    with S = sum_j sin(theta_j) and C = sum_j cos(theta_j). This avoids
    allocating an (N, N) phase-difference matrix on every call, which
    matters because the ODE solver evaluates this function many times.

    Parameters
    ----------
    theta : array_like, shape (n,)
        Current phases of the oscillators.
    t : float
        Current time. It is unused because the right-hand side does not
        depend on t explicitly, but scipy.integrate.odeint requires the
        callback signature func(y, t, *args).
    omega : array_like or float
        Natural frequencies, broadcast against ``theta``.
    K : float
        Coupling strength.
    N : int or float
        Normalising size used in the ``K / N`` prefactor (normally
        ``len(theta)``).

    Returns
    -------
    numpy.ndarray
        d(theta)/dt, with the same shape as ``theta``.
    """
    theta = np.asarray(theta, dtype=float)
    sin_theta = np.sin(theta)
    cos_theta = np.cos(theta)

    # Population-wide sums shared by every oscillator's interaction term.
    sin_total = sin_theta.sum()
    cos_total = cos_theta.sum()

    # (K/N) * sum_j sin(theta_j - theta_i), via the angle-difference identity.
    interaction = (K / N) * (sin_total * cos_theta - cos_total * sin_theta)

    return omega + interaction

# --- 1. Set up Parameters ---
N = 50              # Number of oscillators
K = 2.0             # Coupling strength. For natural frequencies drawn from a
                    # standard normal g(omega), the N -> infinity critical
                    # coupling is K_c = 2 / (pi * g(0)) = 2*sqrt(2*pi)/pi ~= 1.60,
                    # so K = 2.0 > K_c is expected to produce synchronization.
T = 10              # Total simulated time (the time grid ends exactly at T)
dt = 0.01           # Time step
# np.arange(0, T, dt) would exclude the endpoint T, and with a float step its
# length is sensitive to rounding. linspace pins both endpoints and the count.
n_steps = int(round(T / dt))
time_points = np.linspace(0.0, T, n_steps + 1)

# Initialize natural frequencies (omega) and initial phases (theta0)
np.random.seed(42)  # Seed NumPy's global RNG so runs are reproducible
omega = np.random.normal(loc=0.0, scale=1.0, size=N)  # natural frequencies ~ Normal(0, 1)
theta0 = np.random.uniform(0, 2*np.pi, size=N)        # initial phases ~ Uniform[0, 2*pi)

# --- 2. Solve the ODE ---
# odeint forwards `args` to kuramoto_derivative after (theta, t).
# `solution` has shape (len(time_points), N): one row of phases per time sample.
solution = odeint(kuramoto_derivative, theta0, time_points, args=(omega, K, N))

# --- 3. Visualization ---
# Draw each oscillator as sin(theta_i(t)). The integrated phases are not
# wrapped to [0, 2*pi), so taking sin() keeps every trace within [-1, 1]
# and makes in-phase oscillators visibly overlap.
phase_fig, phase_ax = plt.subplots(figsize=(12, 6))
phase_ax.plot(time_points, np.sin(solution), alpha=0.6)  # one line per oscillator (column)
phase_ax.set_title(f"Kuramoto Model Simulation (N={N}, K={K})")
phase_ax.set_xlabel("Time")
phase_ax.set_ylabel("sin(theta_i)")
phase_ax.grid(True, alpha=0.3)
phase_fig.tight_layout()
plt.show()

# Optional: quantify synchronization with the Kuramoto order parameter
#     R(t) = | (1/N) * sum_j exp(i * theta_j(t)) |
# R near 0 means incoherent phases; R close to 1 means the oscillators are in sync.
phase_vectors = np.exp(1j * solution)                  # unit phasors, one column per oscillator
order_parameter = np.abs(phase_vectors.mean(axis=1))   # average over oscillators at each time

r_fig, r_ax = plt.subplots(figsize=(12, 4))
r_ax.plot(time_points, order_parameter, color='red', linewidth=2)
r_ax.set_title("Synchronization Order Parameter R(t)")
r_ax.set_ylim(0, 1.1)
r_ax.set_xlabel("Time")
r_ax.set_ylabel("R (0=disorder, 1=sync)")
r_ax.grid(True)
plt.show()
