use std::{collections::{BTreeMap, BTreeSet}, f32, io::stdin, sync::{Arc, Mutex, MutexGuard, atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering}}};

use petgraph::{graph::NodeIndex, visit::EdgeRef};

type Distance = f32;
type Node = u32;
type Graph = petgraph::Graph<String, Distance, petgraph::Undirected>;

/// Find the length of the shortest path from the given node to all other nodes
/// in the graph. If the destination is not reachable from the starting node the
/// distance is `f32::INFINITY`.
///
/// Nodes must be numbered `[0..]`, i.e. starting at 0.
///
/// Negative edge weights are not supported.
///
/// # Panics
///
/// Not implemented yet: the body is `todo!()`, so every call currently panics.
///
/// NOTE: The type of the `delta_stepping` function should not change (since that
/// is what the test suite expects), but you are free to change the types of all
/// other functions and data structures in this module as you require.
pub fn delta_stepping(
    // The number of threads we may use
    thread_count: usize,
    // Whether to print intermediate states to the console, for debugging purposes
    verbose: bool,
    // Graph to analyse
    graph: &Graph,
    // Delta (step width, bucket width)
    delta: Distance,
    // Index of the starting node
    source: Node,
) -> Vec<Distance> {
    // NOTE(review): the result is presumably one entry per node, indexed by
    // node number — confirm against the test suite when implementing.
    todo!()
}

// Initialise algorithm state
fn initialise(
    graph: &Graph,
    delta: Distance,
    source: Node
) -> (Buckets, TentativeDistances) {
    todo!()
}

// Take a single step of the algorithm.
// That is, one iteration of the outer while loop.
// You may change the type of this function.
//
// Not implemented yet: the body is `todo!()` and the return type
// `(/* TODO */)` is currently just the unit type `()`.
//
// NOTE(review): unlike the other helpers, this takes `distances` as
// `&mut TentativeDistances`; since every entry is already behind a `Mutex`,
// a shared `&` may suffice — confirm when implementing.
fn step(
    verbose: bool,
    thread_count: usize,
    graph: &Graph,
    delta: Distance,
    buckets: &Buckets,
    distances: &mut TentativeDistances,
) -> (/* TODO */) {
    todo!()
}


// Once all buckets are empty, the tentative distances are finalised and the
// algorithm terminates.
//
// Returns `true` when every bucket in the cyclic array is empty, stopping at
// the first non-empty one. Each bucket is locked only while it is inspected,
// so if other threads insert concurrently the answer is a snapshot, not an
// atomic view of all buckets at once.
fn all_buckets_empty(
    buckets: &Buckets
) -> bool {
    buckets
        .bucket_array
        .iter()
        .all(|bucket| bucket.lock().unwrap().is_empty())
}


// Return the index of the first non-empty bucket. Assumes that there is at
// least one non-empty bucket remaining.
//
// Not implemented yet: the body is `todo!()`.
//
// NOTE(review): it is not specified whether the result is a real bucket index
// (the same kind of index as `Buckets::first_bucket`) or a slot in the cyclic
// `bucket_array`; decide and document this before implementing.
fn find_next_bucket(
    buckets: &Buckets
) -> usize {
    todo!()
}


// Create requests of (node, distance) pairs that fulfil the given predicate.
//
// Not implemented yet: the body is `todo!()` and the return type
// `(/* TODO */)` is currently just the unit type `()`.
//
// NOTE(review): `light` presumably selects light edges (weight <= `delta`)
// versus heavy edges (weight > `delta`) as in the delta-stepping paper, and
// `verts` is presumably the set of nodes whose incident edges generate the
// requests — confirm when implementing.
fn find_requests(
    thread_count: usize,
    light: bool,
    delta: Distance,
    graph: &Graph,
    verts: &BTreeSet<u32>,
    distances: &TentativeDistances
) -> (/* TODO */) {
    todo!()
}


// Execute requests for each of the given (node, distance) pairs.
//
// Not implemented yet: the body is `todo!()` and `req` still has the
// placeholder unit type `()`.
//
// NOTE(review): `thread_count` is declared `mut` although nothing mutates it
// yet; either it is meant to be adjusted (e.g. capped by the number of
// requests) or the `mut` should be dropped — confirm when implementing.
fn relax_requests(
    mut thread_count: usize,
    buckets: &Buckets,
    distances: &TentativeDistances,
    delta: Distance,
    req: (/* TODO */)
) {
    todo!()
}

// Execute a single relaxation, moving the given node to the appropriate bucket
// as necessary.
//
// If `new_distance` is smaller than the node's tentative distance:
//  1. the node is removed from the bucket of its old distance (if finite),
//  2. it is inserted into the bucket of `new_distance`,
//  3. its tentative distance is lowered to `new_distance`.
// Otherwise nothing changes (this includes a NaN `new_distance`).
//
// Bucket `r` covers distances `[r * delta, (r + 1) * delta)` and is stored at
// slot `r % bucket_array.len()` of the cyclic array, the same mapping that
// `print_buckets` uses.
//
// Locking: the node's distance lock is held for the whole update, so its
// distance and its bucket membership change together even when several threads
// relax the same node. Bucket locks are only taken while that distance lock is
// held. Callers must therefore never acquire a distance lock while holding a
// bucket lock, or they may deadlock with a concurrent relaxation.
//
// Panics if `node` is out of range for `distances` or `bucket_array` is empty.
fn relax(
    buckets: &Buckets,
    distances: &TentativeDistances,
    delta: Distance,
    (node, new_distance): (Node, Distance)      // (w, x) in the paper
) {
    let slot_count = buckets.bucket_array.len();
    // Real bucket index floor(d / delta), folded into the cyclic array.
    let slot_of = |distance: Distance| (distance / delta).floor() as usize % slot_count;

    let mut tentative = distances[node as usize].lock().unwrap();
    if new_distance < *tentative {
        // An infinite distance means the node has not been placed in any bucket.
        if tentative.is_finite() {
            buckets.bucket_array[slot_of(*tentative)]
                .lock()
                .unwrap()
                .remove(&node);
        }
        buckets.bucket_array[slot_of(new_distance)]
            .lock()
            .unwrap()
            .insert(node);
        *tentative = new_distance;
    }
}

// -----------------------------------------------------------------------------
// Starting framework
// -----------------------------------------------------------------------------
//
// Here is a collection of (data) types and utility functions that you can use.
// You are free to change these as necessary.
//

/// Tentative distance of every node, indexed by node number. Every entry sits
/// behind its own `Mutex`, so entries are locked independently.
type TentativeDistances = Vec<Mutex<f32>>;


/// Delta-stepping buckets, stored in a cyclic array.
///
/// The bucket with real index `r` covers tentative distances in
/// `[r * delta, (r + 1) * delta)` and is kept at slot
/// `r % bucket_array.len()`.
struct Buckets {
    /// Real index of the first bucket (`j`).
    first_bucket: AtomicU32,
    /// Cyclic array of buckets, each a lock-protected set of node numbers.
    bucket_array: Vec<Mutex<BTreeSet<u32>>>,
}


/// Runs `f` once on each of `thread_count` threads and waits for all of them.
///
/// Each invocation receives its thread index in `0..thread_count`. When
/// `thread_count` is 1, `f(0)` runs directly on the calling thread instead of
/// spawning a new one; when it is 0, `f` is never called.
///
/// # Panics
///
/// If any spawned invocation panics, this function panics after every thread
/// has been joined (the behaviour of [`std::thread::scope`]).
pub fn fork_threads<F: Fn(u32) + Send + Sync>(thread_count: u32, f: F) {
    if thread_count == 1 {
        f(0);
        return;
    }

    // Scoped threads may borrow `f`: the scope joins all of them before it
    // returns, so the borrow never outlives this call.
    let f = &f;
    std::thread::scope(|s| {
        for idx in 0..thread_count {
            s.spawn(move || f(idx));
        }
    });
}


// When `verbose` is set: print `title`, the tentative distance of every node and
// the contents of every bucket, then block until the user presses enter.
// Does nothing when `verbose` is false.
fn print_verbose(
    verbose: bool,
    title: &str,
    graph: &Graph,
    delta: Distance,
    buckets: &Buckets,
    distances: &TentativeDistances
) {
    if !verbose {
        return;
    }

    println!("# {}", title);
    print_current_state(graph, distances);
    print_buckets(graph, delta, buckets, distances);

    // Pause for the user; a failed read is deliberately ignored.
    println!("Press enter to continue");
    let mut discarded = String::new();
    let _ = stdin().read_line(&mut discarded);
}


/// Print the current state of the algorithm: one row per node with its label
/// and tentative distance, where a distance of `f32::INFINITY` is shown as `-`.
pub fn print_current_state(
    g: &Graph,
    distances: &TentativeDistances
) {
    println!("  Node  |  Label  |  Distance");
    println!("--------+---------+------------");

    for (idx, label) in g.node_weights().enumerate() {
        // `f32` is `Copy`: dereferencing copies the value out, and the guard is
        // released at the end of this statement.
        let node_dist = *distances[idx].lock().unwrap();
        if node_dist == f32::INFINITY {
            println!("  {idx:>4}  |  {label:>5}  |  -");
        } else {
            println!("  {idx:>4}  |  {label:>5}  |  {node_dist:?}");
        }
    }
}


// Print every bucket, starting from the first one (real index `first_bucket`)
// and walking once around the cyclic array. Each bucket is labelled with its
// real index and the half-open distance range it covers.
fn print_buckets(
    graph: &Graph,
    delta: Distance,
    buckets: &Buckets,
    distances: &TentativeDistances
) {
    let bucket_count = buckets.bucket_array.len();
    let first = buckets.first_bucket.load(Ordering::Relaxed) as usize;

    for idx_ in (0..bucket_count).map(|offset| first + offset) {
        println!("Bucket {idx_:?}: [{}, {})", idx_ as f32 * delta, (idx_ + 1) as f32 * delta);
        // Real index -> slot in the cyclic array.
        print_bucket(graph, &buckets.bucket_array[idx_ % bucket_count], distances);
    }
}


// Print the current bucket, i.e. the one whose real index is `first_bucket`,
// together with the distance range it covers.
fn print_current_bucket(
    graph: &Graph,
    delta: Distance,
    buckets: &Buckets,
    distances: &TentativeDistances
) {
    let first = buckets.first_bucket.load(Ordering::Relaxed) as usize;
    // Resolve the slot before printing anything, so an empty bucket array
    // panics before any output is produced.
    let current = &buckets.bucket_array[first % buckets.bucket_array.len()];

    let (lower, upper) = (first as f32 * delta, (first + 1) as f32 * delta);
    println!("Bucket {first}: [{}, {})", lower, upper);
    print_bucket(graph, current, distances);
}


/// Print a given bucket: one row per node it contains, in ascending node order,
/// showing the node's label (or `-` if the graph has no such node) and its
/// tentative distance in `{:?}` form (an infinite distance prints as `inf`).
pub fn print_bucket(
    g: &Graph,
    bucket: &Mutex<BTreeSet<u32>>,
    distances: &TentativeDistances
) {
    println!("  Node  |  Label  |  Distance");
    println!("--------+---------+-----------");

    // Snapshot the bucket and release its lock before taking any per-node
    // distance lock, so the bucket lock is never held while waiting on a node.
    // A `BTreeSet` iterates in ascending order, so no extra sort is needed.
    let nodes: Vec<u32> = bucket.lock().unwrap().iter().copied().collect();

    for node_idx in nodes {
        let node_dist = *distances[node_idx as usize].lock().unwrap();

        match g.node_weight(NodeIndex::new(node_idx as usize)) {
            Some(label) => println!("  {node_idx:>4}  |  {label:>5}  |  {node_dist:?}"),
            None => println!("  {node_idx:>4}  |   -   |  {node_dist:?}"),
        }
    }
}
