#include <xmtc.h>
#include "profile.h"
#include "lib/arith.h"
#include "lib/fi.h"
#include "lib/shift.h"
#include "lib/sum.h"

/* XMT prefix-sum base registers. They are used as targets of ps() in this file
 * (mtf_decode aliases psbr1 as its long-run counter n_runs_base); psbr2 is not
 * referenced in the code visible here. */
psBaseReg psbr1;
psBaseReg psbr2;

/* NOTE(review): presumably selects the kspawn variant provided by kspawn.h —
 * confirm against that header. */
#define USE_KSPAWN
#include "kspawn.h"

/* Fallback in case <limits.h> has not supplied UCHAR_MAX. */
#ifndef UCHAR_MAX
#define UCHAR_MAX 255
#endif

/* Number of distinct byte values; capacity of every per-symbol table below. */
#define ALPHABET_SIZE (UCHAR_MAX + 1)

/* SWAP3 exchanges x and y using w as scratch; SWAP supplies a scratch variable
 * of x's type (GNU typeof). Arguments are not parenthesized and are evaluated
 * more than once, so pass only simple lvalues without side effects. */
#define SWAP3(w, x, y) do { w = x; x = y; y = w; } while (0)
#define SWAP(x, y) do { typeof(x) temp; SWAP3(temp, x, y); } while (0)


/* Codes for the bits of a run length: RUNA is a 0 bit, RUNB a 1 bit.
 * mtf_encode writes every other symbol as its MTF index + 1. */
#define RUNA 0
#define RUNB 1

/* Fixed-size bit set with one bit per byte value (assumes ALPHABET_SIZE is a
 * multiple of BITS_PER_WORD, which holds for 32-bit int). */
#define BITS_PER_WORD (sizeof(int) * 8)
#define ALPHASET_WORDS (ALPHABET_SIZE / BITS_PER_WORD)
typedef int alphaset_t[ALPHASET_WORDS];

/* Empty the bit set s by zeroing all of its words. */
void alphaset_clear(alphaset_t s) {
	int w = ALPHASET_WORDS;
	while (w > 0) {
		w--;
		s[w] = 0;
	}
}
/* Insert symbol c (assumed 0 <= c < ALPHABET_SIZE) into the bit set s.
 * The mask is built from an unsigned 1: for bit index BITS_PER_WORD - 1 a
 * signed `1 << 31` overflows int, which is undefined behavior. */
void __inline__ alphaset_add(alphaset_t s, int c) {
	s[c / BITS_PER_WORD] |= 1u << (c % BITS_PER_WORD);
}
/* Return 1 if symbol c (assumed 0 <= c < ALPHABET_SIZE) is in the bit set s,
 * else 0. The word is tested against an unsigned mask so that bit index
 * BITS_PER_WORD - 1 does not trigger signed-shift overflow (undefined
 * behavior with `1 << 31`). */
int __inline__ alphaset_test(alphaset_t s, int c) {
	return ((unsigned)s[c / BITS_PER_WORD] & (1u << (c % BITS_PER_WORD))) != 0;
}

/* Encoder MTF list: entries[0..size) holds distinct symbols, the most recently
 * moved-to-front symbol first (see fwd_alpha_table). */
typedef struct {
	int size;	/* number of valid entries */
	int entries[ALPHABET_SIZE];
} alpha_table_f;

/* Decoder MTF table. Its valid length is passed separately (n_used_chars), not
 * stored. In mtf_decode, tables[0] holds the used symbols in ascending order
 * and the other tables hold index permutations composed by alpha_r_setadd. */
typedef struct {
	int entries[ALPHABET_SIZE];
} alpha_table_r;

/* Forward declarations: alpha_setadd below calls both before their definitions. */
int fwd_alpha_table(alpha_table_f *table, int symbol, int *charset, int n_charset);
void fwd_alpha_table_append(alpha_table_f *table, int symbol);

/* Copy MTF list r into l; does nothing when both name the same object. */
void alpha_set(alpha_table_f *l, const alpha_table_f *r) {
	int count;
	int idx;

	if (l == r)
		return;
	count = r->size;
	l->size = count;
	for (idx = 0; idx < count; idx++)
		l->entries[idx] = r->entries[idx];
}

/* Prefix-sum combine for MTF lists: l = x "then" y.
 * The result lists y's symbols in y's order, followed by x's symbols that do
 * not appear in y, in x's order. l may alias x (updated in place) or y. */
void alpha_setadd(alpha_table_f *l, const alpha_table_f *x, const alpha_table_f *y) {
	int idx;

	if (l == x) {
		/* In place: pull y's symbols to the front, last one first, so they
		 * end up in y's order ahead of x's remaining symbols. */
		for (idx = y->size - 1; idx >= 0; idx--)
			fwd_alpha_table(l, y->entries[idx], NULL, 0);
		return;
	}

	/* Start from a copy of y, then append x's symbols that y lacks. */
	alpha_set(l, y);
	if (l->size <= 6 || x->size <= 8) {
		/* Small lists: a linear membership scan per symbol is enough. */
		for (idx = 0; idx < x->size; idx++)
			fwd_alpha_table_append(l, x->entries[idx]);
	}
	else {
		/* Larger lists: test membership against a bit set of y's symbols. */
		alphaset_t present;
		int sym;

		alphaset_clear(present);
		for (idx = 0; idx < l->size; idx++)
			alphaset_add(present, l->entries[idx]);
		for (idx = 0; idx < x->size; idx++) {
			sym = x->entries[idx];
			if (!alphaset_test(present, sym))
				l->entries[l->size++] = sym;
		}
	}
}

/* Instantiate the in-place prefix sum from lib/psum_inplace_inc.c for encoder
 * MTF lists: T is the element type, T_SET copies and T_SETADD combines
 * (alpha_setadd). NOTE(review): presumably this produces
 * prefix_sum_alpha_table_f, which mtf_encode calls — confirm in lib/sum.h. */
#define T alpha_table_f
#define T_SET(l, r) alpha_set(&(l), &(r))
#define T_SETADD(l, x, y) alpha_setadd(&(l), &(x), &(y))
DECLARE_PREFIX_SUM(alpha_table_f)
#include "lib/psum_inplace_inc.c"
#undef T
#undef T_SET
#undef T_SETADD

/* Set the first n_used_chars entries of table to the identity permutation. */
void init_alpha_table_r(alpha_table_r *table, int n_used_chars) {
	int slot = 0;
	while (slot < n_used_chars) {
		table->entries[slot] = slot;
		slot++;
	}
}

/* Copy the first n_used_chars entries of r into l; no-op if l and r coincide. */
void alpha_r_set(alpha_table_r *l, const alpha_table_r *r, int n_used_chars) {
	int idx;

	if (l == r)
		return;
	for (idx = n_used_chars - 1; idx >= 0; idx--)
		l->entries[idx] = r->entries[idx];
}

/* Compose decoder tables: l[i] = x[y[i]] for i < n_used_chars.
 * When l aliases x, x is snapshotted first so that reads are not clobbered by
 * the writes to l. */
void alpha_r_setadd(alpha_table_r *l, const alpha_table_r *x, const alpha_table_r *y, int n_used_chars) {
	alpha_table_r snapshot;
	const alpha_table_r *src = x;
	int idx;

	if (l == x) {
		alpha_r_set(&snapshot, x, n_used_chars);
		src = &snapshot;
	}
	for (idx = 0; idx < n_used_chars; idx++)
		l->entries[idx] = src->entries[y->entries[idx]];
}

/* Table length for the alpha_table_r prefix sum. T_SET/T_SETADD take no extra
 * argument, so mtf_decode stores n_charset here before calling
 * prefix_sum_alpha_table_r. NOTE(review): being a global, overlapping decodes
 * with different alphabets would presumably interfere. */
int global_n_used_chars;

/* Instantiate the in-place prefix sum for decoder tables (composition via
 * alpha_r_setadd). */
#define T alpha_table_r
#define T_SET(l, r) alpha_r_set(&(l), &(r), global_n_used_chars)
#define T_SETADD(l, x, y) alpha_r_setadd(&(l), &(x), &(y), global_n_used_chars)
DECLARE_PREFIX_SUM(alpha_table_r)
#include "lib/psum_inplace_inc.c"
#undef T
#undef T_SET
#undef T_SETADD

/* Move symbol to the front of the MTF list in table and return its rank.
 *
 * If symbol is already listed, the entries before it shift down by one and its
 * previous index is returned. Otherwise every entry shifts down, the list grows
 * by one, and the result is symbol's index in charset plus the number of listed
 * entries greater than symbol. When charset is the ascending list of all symbols
 * (as in mtf_encode), that is symbol's position in "listed entries, then the
 * unlisted charset symbols in ascending order". charset may be NULL when
 * n_charset is 0. */
int fwd_alpha_table(alpha_table_f *table, int symbol, int *charset, int n_charset) {
	int carried = symbol;
	int greater = 0;
	int pos;
	int rank;

	for (pos = 0; pos < table->size; pos++) {
		int displaced = table->entries[pos];
		table->entries[pos] = carried;
		carried = displaced;
		if (carried == symbol)
			return pos;
		if (carried > symbol)
			greater++;
	}

	table->entries[table->size++] = carried;

	rank = 0;
	while (rank < n_charset && charset[rank] != symbol)
		rank++;
	return rank + greater;
}

/* Append symbol to the end of the MTF list unless it is already present. */
void fwd_alpha_table_append(alpha_table_f *table, int symbol) {
	int pos = 0;

	while (pos < table->size && table->entries[pos] != symbol)
		pos++;
	if (pos == table->size)
		table->entries[table->size++] = symbol;
}

/* Inverse MTF step: take the entry at index, shift the entries before it down
 * by one, put it at the front, and return it. */
int rev_alpha_table(alpha_table_r *table, int index) {
	int *slots = table->entries;
	int picked = slots[index];
	int pos = index;

	while (pos > 0) {
		slots[pos] = slots[pos - 1];
		pos--;
	}
	slots[0] = picked;
	return picked;
}

/*
 * Parallel move-to-front + run-length encoder (RUNA/RUNB run codes).
 *
 * The input is split into blocks of k symbols, handled by parallel threads:
 *   pass 1: each block counts its output symbols and records, in MTF order,
 *           the distinct symbols that start a new run in it (tables[$ + 1]);
 *           counts of runs that cross block boundaries are added into
 *           rle_counts with psm().
 *   prefix sums then give each block its output offset (block_sizes), the
 *           count of the run it inherits (rle_counts) and the MTF state at its
 *           start (tables).
 *   pass 2: each block writes its symbols at its offset.
 *
 * A run is held as "repeat count + 1" (a fresh symbol starts at 1). It is
 * written as the bits below the leading 1, least significant first: RUNA for 0,
 * RUNB for 1. Every other symbol is written as its MTF index + 1. The initial
 * MTF table is the used symbols in ascending order.
 *
 * input        - n symbols, used as indices into 256-entry tables (so assumed
 *                to lie in 0..255).
 * output       - receives the encoded stream (the returned number of symbols).
 * used_chars   - out: 256 flags, 1 for every symbol occurring in the input.
 * n_used_chars - out: number of used symbols + 2.
 * Returns the number of output symbols, including the end-of-stream marker
 * n_charset + 1.
 *
 * NOTE(review): tables, block_sizes and rle_counts are stack VLAs of nb + 1
 * elements (about 1 KiB per block for tables); large inputs need a big stack.
 */
int mtf_encode(const int *input, int *output, int n, int *used_chars, int *n_used_chars) {
	int k = 256;

	STARTTIME();
	int i;
	int nb = CDIV(n, k);
	alpha_table_f tables[nb + 1];
	int block_sizes[nb + 1];
	flagged_int rle_counts[nb + 1];

	// The front of the initial MTF table is the smallest used symbol
	// (charset[0] below, tables[0].entries[0] in mtf_decode). Block 0's leading
	// run must be measured against that symbol. Seeding it with 0 instead let an
	// input that starts with its smallest, non-zero symbol emit MTF index 0 as
	// the code 1, which collides with RUNB.
	int front = 0;
	if (n > 0) {
		int block_min[nb];
		spawn(0, nb - 1) {
			int j;
			int begin = $ * k;
			int end = begin + k;
			if (end > n)
				end = n;
			int m = input[begin];
			for (j = begin + 1; j < end; j++) {
				if (input[j] < m)
					m = input[j];
			}
			block_min[$] = m;
		}
		front = block_min[0];
		for (i = 1; i < nb; i++) {
			if (block_min[i] < front)
				front = block_min[i];
		}
	}

	// rle_counts[$] accumulates the count of the run still open when block $
	// starts; slot 0 starts at 1, i.e. an empty run of the front symbol.
	rle_counts[0] = 1;
	spawn(1, nb) {
		rle_counts[$] = 0;
	}

	tables[0].size = 0;
	spawn(0, nb - 1) {
		int i;
		int begin = $ * k;
		int end = begin + k;
		if (end > n)
			end = n;
		int rle_char = $ > 0 ? input[begin - 1] : front;
		int rle_count = 0;
		int block_size = 0;
		int first_run = 1;

		alpha_table_f *table = &tables[$ + 1];
		table->size = 0;

		for (i = begin; i < end; i++) {
			int ch = input[i];
			if (ch == rle_char) {
				rle_count++;
			}
			else {
				if (first_run) {
					// The leading run continues the previous block's run:
					// hand its length to that run's counter.
					psm(rle_count, rle_counts[$]);
					rle_count = 1;
					first_run = 0;
				}
				else {
					// One output symbol per bit below the leading 1.
					while (rle_count > 1) {
						block_size++;
						rle_count >>= 1;
					}
				}
				rle_char = ch;
				{
					// Move ch to the front of this block's MTF list: entries
					// shift down by one until ch's old slot; if ch was not
					// listed, the list grows by one.
					int j;
					int last = ch;

					for (j = 0; j < table->size; j++) {
						SWAP(table->entries[j], last);
						if (last == ch)
							goto cont;
					}

					table->entries[table->size] = last;
					table->size++;
				cont:
					j = j;
				}
				block_size++;
			}
		}

		block_sizes[$ + 1] = block_size;
		// NOTE(review): flagged_int/FI_SET_FLAG come from lib/fi.h; presumably
		// a flagged entry (a block made only of the inherited run) is carried
		// into its predecessor by prefix_sum_flagged_int — confirm.
		if (block_size == 0)
			rle_count = FI_SET_FLAG(rle_count);
		psm(rle_count, rle_counts[$ + 1]);
	}
	SHOWTIME("pass1");

	prefix_sum_flagged_int(rle_counts, rle_counts, nb + 1);

	// Add the bits of the inherited run, which a block writes before its first
	// own symbol; blocks without own symbols write nothing.
	spawn(0, nb - 1) {
		int size = block_sizes[$ + 1];
		if (size != 0) {
			int count = rle_counts[$];
			if (count > 1)
				block_sizes[$ + 1] = size + floor_log2(count);
		}
	}

	block_sizes[0] = 0;
	prefix_sum_int(block_sizes + 1, block_sizes + 1, nb);
	SHOWTIME("rle_psum");
	
	#if 1
	prefix_sum_alpha_table_f(tables, tables, nb + 1);
	#else
	alpha_table_f work_table;

	alpha_set(&work_table, &tables[0]);
	for (i = 1; i <= nb; i++) {
		alpha_setadd(&work_table, &work_table, &tables[i]);
		alpha_set(&tables[i], &work_table);
	}
	#endif
	SHOWTIME("sum");

	for (i = 0; i < 256; i++)
		used_chars[i] = 0;
	for (i = 0; i < tables[nb].size; i++)
		used_chars[tables[nb].entries[i]] = 1;
	// The front symbol may occur only inside block 0's leading run, which
	// never reaches the per-block tables, so mark it explicitly.
	if (n > 0)
		used_chars[front] = 1;

	int charset[256];
	int n_charset = 0;
	for (i = 0; i < 256; i++) {
		if (used_chars[i])
			charset[n_charset++] = i;
	}
	*n_used_chars = n_charset + 2;
	SHOWTIME("charset");

	spawn(0, nb - 1) {
		int i;
		int begin = $ * k;
		int end = begin + k;
		if (end > n)
			end = n;
		int rle_char = $ > 0 ? input[begin - 1] : front;
		int rle_count = rle_counts[$];
		int pos = block_sizes[$];

		alpha_table_f *table = &tables[$];

		// Skip the first run if it is the continuation of the preceding run
		// Also, ignore blocks consisting only of a continuation
		for (i = begin; i < end; i++) {
			if (input[i] != rle_char)
				break;
		}

		for (; i < end; i++) {
			int ch = input[i];
			if (ch == rle_char) {
				rle_count++;
			}
			else {
				while (rle_count > 1) {
					output[pos++] = ((rle_count & 1) == 0) ? RUNA : RUNB;
					rle_count >>= 1;
				}
				rle_char = ch;
				output[pos++] = fwd_alpha_table(table, ch, charset, n_charset) + 1;
			}
		}
	}

	// The run still open at the end of the input is written serially.
	int rle_count = rle_counts[nb];
	int pos = block_sizes[nb];
	while (rle_count > 1) {
		output[pos++] = ((rle_count & 1) == 0) ? RUNA : RUNB;
		rle_count >>= 1;
	}

	output[pos++] = n_charset + 1; // end-of-stream marker

	SHOWTIME("pass2");

	return pos;
}

/* Runs of at least this many symbols are filled in a separate parallel phase. */
#define RLE_SPLIT_THRESHOLD 128

/*
 * Inverse of mtf_encode: expands RUNA/RUNB runs and undoes the move-to-front
 * transform, in parallel over blocks of k input symbols.
 *   pass 1: each block computes its output length and the index permutation
 *           its MTF updates apply (tables[$ + 1], starting from the identity).
 *   prefix sums give each block its output offset and its starting MTF table
 *           (tables[0] = the used symbols in ascending order).
 *   pass 2: each block writes its output. Runs of at least
 *           RLE_SPLIT_THRESHOLD symbols are only recorded; they are later split
 *           into pieces of at most RLE_SPLIT_THRESHOLD symbols and filled in
 *           parallel.
 * input      - n encoded symbols.
 * output     - receives the decoded symbols.
 * used_chars - the 256 flags produced by mtf_encode.
 * Returns the number of decoded symbols.
 *
 * NOTE(review): every non-RUN symbol, including mtf_encode's end-of-stream
 * marker (n_charset + 1), is decoded as an MTF index here; presumably callers
 * pass n without that marker — confirm.
 */
int mtf_decode(const int *input, int *output, int n, int *used_chars) {
	STARTTIME();

	const int k = 256;

	int nb = CDIV(n, k);
	alpha_table_r tables[nb + 1];
	int output_sizes[nb + 1];

	int i;
	int n_charset = 0;
	for (i = 0; i < 256; i++) {
		if (used_chars[i])
			tables[0].entries[n_charset++] = i;
	}
	SHOWTIME("charset");

	spawn(0, nb - 1) {
		int i;
		int begin = $ * k;
		int end = begin + k;
		if (end > n)
			end = n;
		int size = 0;

		alpha_table_r *table = &tables[$ + 1];
		init_alpha_table_r(table, n_charset);

		int runlen = 0;
		int runlen_bits = 0;

		// Skip RUN characters at the beginning of a block if they are the
		// continuation of a series of RUN characters that begins in the
		// preceding block (bounded by n so an all-RUN tail cannot run off
		// the end of the input)
		i = begin;
		if (begin > 0 && (input[begin - 1] == RUNA || input[begin - 1] == RUNB)) {
			while (i < n && (input[i] == RUNA || input[i] == RUNB))
				i++;
		}

		for ( ; i < end; i++) {
			int ch = input[i];
			if (ch == RUNA || ch == RUNB) {
				// Collect run bits from the top of the word down.
				SRL(runlen, runlen, 1);
				runlen_bits++;
				if (ch == RUNB)
					runlen |= 0x80000000;
			}
			else {
				if (runlen_bits > 0) {
					// Prepend the implicit leading 1 and right-align:
					// runlen becomes repeat count + 1.
					SRL(runlen, runlen, 1);
					runlen |= 0x80000000;
					SRLV(runlen, runlen, 31 - runlen_bits);
					runlen_bits = 0;
					//printf("run of %dx\n", runlen);
					size += runlen - 1;
					runlen = 0;
				}
				/* output[i] = */ rev_alpha_table(table, ch - 1);
				size++;
			}
		}

		// Still inside of a run: finish it by reading past the block end
		// (2 is a non-RUN sentinel for the end of the input)
		while (runlen_bits > 0) {
			int ch = i < n ? input[i] : 2;
			if (ch == RUNA || ch == RUNB) {
				SRL(runlen, runlen, 1);
				runlen_bits++;
				if (ch == RUNB)
					runlen |= 0x80000000;
			}
			else {
				SRL(runlen, runlen, 1);
				runlen |= 0x80000000;
				SRLV(runlen, runlen, 31 - runlen_bits);
				runlen_bits = 0;
				//printf("run of %dx\n", runlen);
				size += runlen - 1;
				runlen = 0;
			}
			i++;
		}

		output_sizes[$ + 1] = size;
	}
	SHOWTIME("pass1");

	output_sizes[0] = 0;
	prefix_sum_int(output_sizes + 1, output_sizes + 1, nb);
	#if 1
	global_n_used_chars = n_charset;
	prefix_sum_alpha_table_r(tables, tables, nb);
	#else
	for (i = 1; i < nb; i++)
		alpha_r_setadd(&tables[i], &tables[i - 1], &tables[i], n_charset);
	#endif
	SHOWTIME("sum");

	// Long-run records. A recorded run has at least RLE_SPLIT_THRESHOLD
	// symbols and is only split while longer than the threshold, so every
	// piece has at least (RLE_SPLIT_THRESHOLD + 1) / 2 symbols. All pieces lie
	// inside the output, so this bounds their number. (Sizing by the input
	// length n overflowed when a few RUN symbols encoded a very long run.)
	int max_runs = output_sizes[nb] / ((RLE_SPLIT_THRESHOLD + 1) / 2) + 1;
	int run_begin[max_runs];
	int run_length[max_runs];
	int run_char[max_runs];
	#define n_runs_base psbr1
	n_runs_base = 0;

	spawn(0, nb - 1) {
		int i;
		int begin = $ * k;
		int end = begin + k;
		if (end > n)
			end = n;
		int pos = output_sizes[$];

		alpha_table_r *table = &tables[$];

		int last_char = table->entries[0];
		int runlen = 0;
		int runlen_bits = 0;

		// Skip RUN characters at the beginning of a block if they are the
		// continuation of a series of RUN characters that begins in the
		// preceding block (bounded by n, as in pass 1)
		i = begin;
		if (begin > 0 && (input[begin - 1] == RUNA || input[begin - 1] == RUNB)) {
			while (i < n && (input[i] == RUNA || input[i] == RUNB))
				i++;
		}

		for ( ; i < end; i++) {
			int ch = input[i];
			if (ch == RUNA || ch == RUNB) {
				SRL(runlen, runlen, 1);
				runlen_bits++;
				if (ch == RUNB)
					runlen |= 0x80000000;
			}
			else {
				if (runlen_bits > 0) {
					SRL(runlen, runlen, 1);
					runlen |= 0x80000000;
					SRLV(runlen, runlen, 31 - runlen_bits);
					runlen--;
					runlen_bits = 0;
					//printf("run of %dx %d\n", runlen, last_char);
					if (runlen < RLE_SPLIT_THRESHOLD) {
						for (; runlen > 0; runlen--)
							output[pos++] = last_char;
					}
					else {
						// Reserve a record slot (ps returns the old counter).
						int irun = 1;
						ps(irun, n_runs_base);
						run_begin[irun] = pos;
						run_length[irun] = runlen;
						run_char[irun] = last_char;

						pos += runlen;
						runlen = 0;
					}
				}
				output[pos++] = last_char = rev_alpha_table(table, ch - 1);
			}
		}

		// Still inside of a run
		while (runlen_bits > 0) {
			int ch = i < n ? input[i] : 2;
			if (ch == RUNA || ch == RUNB) {
				SRL(runlen, runlen, 1);
				runlen_bits++;
				if (ch == RUNB)
					runlen |= 0x80000000;
			}
			else {
				SRL(runlen, runlen, 1);
				runlen |= 0x80000000;
				SRLV(runlen, runlen, 31 - runlen_bits);
				runlen--;
				runlen_bits = 0;
				//printf("run of %dx %d\n", runlen, last_char);
				if (runlen < RLE_SPLIT_THRESHOLD) {
					for (; runlen > 0; runlen--)
						output[pos++] = last_char;
				}
				else {
					int irun  = 1;
					ps(irun, n_runs_base);
					run_begin[irun] = pos;
					run_length[irun] = runlen;
					run_char[irun] = last_char;
				}
			}
			i++;
		}
	}
	SHOWTIME("pass2");

	if (n_runs_base > 0) {
		int n_runs;
		// Halve every piece longer than the threshold until none is left.
		do {
			n_runs = n_runs_base;
			spawn(0, n_runs - 1) {
				int len = run_length[$];
				if (len > RLE_SPLIT_THRESHOLD) {
					int i = 1;
					ps(i, n_runs_base);
					run_begin[i] = run_begin[$] + len / 2;
					run_length[i] = len - len / 2;
					run_char[i] = run_char[$];
					run_length[$] = len / 2;
				}
			}
		} while (n_runs != n_runs_base);

		spawn(0, n_runs - 1) {
			int begin = run_begin[$];
			int end = begin + run_length[$];
			int ch = run_char[$];
			int i;
			for (i = begin; i < end; i++)
				output[i] = ch;
		}
	}
	SHOWTIME("long_runs");

	return output_sizes[nb];
}
