/**
 * Implementation of the synchronization primitives.
 *
 * They are implemented here to avoid recompiling them all the time, and to avoid issues caused by
 * holding references to vtables of classes that have been deallocated.
 */
use core:sync;
use lang:bs:unsafe;

/**
 * Exception thrown when a synchronization primitive is used incorrectly.
 *
 * Passes "sync usage" to the progvis:TypedError base. Creating an instance also calls
 * progvis:program:onFatalException().
 */
class SyncError extends progvis:TypedError {
	// Description of what went wrong.
	Str msg;

	init(Str msg) {
		init("sync usage") {
			msg = msg;
		}
		progvis:program:onFatalException();
	}

	// Output the stored description.
	void message(StrBuf to) : override {
		to << msg;
	}
}

/**
 * Common base class of the synchronization primitives.
 *
 * Every primitive can report which threads are blocked on it, so that the runtime can inspect
 * the queue of waiting threads without knowing the concrete primitive type.
 */
class Waitable {
	// The threads currently blocked on this primitive, identified by number.
	Array<Nat> waitingThreads() : abstract;
}

// Ensure the caller is executing on the ui:Render thread; throws a SyncError otherwise.
void checkRenderThread() {
	if (ui:Render.isCurrent)
		return;

	throw SyncError("This primitive can only be used from the Render thread in the program visualization!");
}

/**
 * Semaphore internals designed to cooperate with the visualization system. If ordinary semaphores
 * were used, we would not be able to abort the system while a thread is waiting on a semaphore.
 *
 * Note: We know that multiple HW threads won't use the same instance, so we can actually get by
 * without locks!
 */
class SemaImpl extends Waitable {
	// Number of tokens currently available.
	Nat count;

	// Threads blocked in `down`. The thread at the front is woken first by `up`.
	Queue<progvis:program:ProgThread> waiting;

	init(Nat count) {
		init {
			count = count;
		}
	}

	// Take one token. If none is available, the calling thread is queued and put to sleep
	// until `up` wakes it. `waitFor` is passed (as a raw pointer) to the sleep call.
	void down(Variant waitFor) {
		checkRenderThread();
		as thread ui:Render {
			unless (caller = progvis:program:findThread(currentUThread())) {
				throw SyncError("This primitive is only intended to be used within the program visualization!");
			}

			if (count == 0) {
				// No token available: get in line. Going to sleep is done last, since otherwise
				// thread switches may happen inside the critical region.
				// Note: lockSleep is a thread call; avoiding it might improve performance.
				waiting.push(caller);
				caller.lockSleep(unsafe:RawPtr(waitFor));
			} else {
				--count;
			}
		}
	}

	// Release one token. A waiting thread, if any, is woken instead of incrementing `count`.
	void up() {
		checkRenderThread();
		as thread ui:Render {
			if (!waiting.any()) {
				++count;
				return;
			}

			var next = waiting.top();
			waiting.pop();
			next.lockWake();
		}
	}

	// Thread IDs of all threads queued in `waiting`, front first.
	Nat[] waitingThreads() : override {
		Nat[] ids;
		for (t in waiting) {
			ids.push(t.threadId);
		}
		ids;
	}
}

/**
 * Condition variable implementation, used together with a SemaImpl that acts as the lock.
 */
class CondImpl extends Waitable {
	// Threads blocked in `wait`. The thread at the front is woken first.
	Queue<progvis:program:ProgThread> waiting;

	// Number of threads currently blocked in `wait`.
	Int waitingCount() {
		waiting.count().int();
	}

	// Wait until signalled. The caller is assumed to hold `lock`: it is released while waiting
	// and acquired again before returning. `condVal` is passed to the sleep call, and `lockVal`
	// to the re-acquisition of `lock`.
	void wait(Variant condVal, Variant lockVal, SemaImpl lock) {
		checkRenderThread();
		as thread ui:Render {
			unless (caller = progvis:program:findThread(currentUThread())) {
				throw SyncError("This primitive is only intended to be used within the program visualization!");
			}

			// Releasing the lock before queueing ourselves is safe due to the threading model in Storm.
			lock.up();
			waiting.push(caller);
			caller.lockSleep(unsafe:RawPtr(condVal));

			// Woken up again: take back the lock before returning.
			lock.down(lockVal);
		}
	}

	// Wake the thread at the front of the queue, if there is one. `lock` is not used.
	void signal(SemaImpl lock) {
		checkRenderThread();
		as thread ui:Render {
			if (!waiting.any())
				return;

			var first = waiting.top();
			waiting.pop();
			first.lockWake();
		}
	}

	// Wake every waiting thread, front first. `lock` is not used.
	void broadcast(SemaImpl lock) {
		checkRenderThread();
		as thread ui:Render {
			while (waiting.any()) {
				signal(lock);
			}
		}
	}

	// Thread IDs of all threads queued in `waiting`, front first.
	Nat[] waitingThreads() : override {
		Nat[] ids;
		for (w in waiting) {
			ids.push(w.threadId);
		}
		ids;
	}
}
