#ifndef BYTESTREAM_H
#define BYTESTREAM_H

#include "memutil.h"
#include "shift.h"

typedef struct {
	int *data;
	int bufpos;
	int pos;
	int buf;
} bytestream_t;

#define BYTE_STREAM_TELL(s) ((s).bufpos)

#define BYTE_STREAM_INIT_READ(s, data_arg, pos_arg) do { \
	s.data = (data_arg); \
	s.bufpos = (pos_arg); \
	s.pos = ALIGN_DOWN(s.bufpos + 4, 4); \
	s.buf = s.data[s.bufpos >> 2] >> (32 - 8 * (s.pos - s.bufpos)); \
} while (0)

#define BYTE_STREAM_READ(s, dest) do { \
	if (s.bufpos == s.pos) { \
		s.buf = WORD_AT_BYTE_OFFSET(s.data, s.pos); \
		s.pos += 4; \
	} \
	dest = s.buf & 0xff; \
	SRL(s.buf, s.buf, 8); \
	s.bufpos++; \
} while (0)

#define BYTE_STREAM_READ4(s, d0, d1, d2, d3) do { \
	if (s.bufpos == s.pos) { \
		s.buf = WORD_AT_BYTE_OFFSET(s.data, s.pos); \
		s.pos += 4; \
		d0 = s.buf & 0xff; \
		SRL(s.buf, s.buf, 8); \
		d1 = s.buf & 0xff; \
		SRL(s.buf, s.buf, 8); \
		d2 = s.buf & 0xff; \
		SRL(s.buf, s.buf, 8); \
		d3 = s.buf & 0xff; \
		s.bufpos += 4; \
	} \
	else { \
		d0 = s.buf & 0xff; \
		SRL(s.buf, s.buf, 8); \
		s.bufpos++; \
		BYTE_STREAM_READ(s, d1); \
		BYTE_STREAM_READ(s, d2); \
		BYTE_STREAM_READ(s, d3); \
	} \
} while (0)

#define BYTE_STREAM_INIT_WRITE(s, data_arg, pos_arg) do { \
	s.data = (data_arg); \
	s.bufpos = (pos_arg); \
	s.pos = s.bufpos; \
} while (0);

static int __inline__ byte_stream_flush(bytestream_t s) {
	int n_bytes = s.bufpos - s.pos;
	if (n_bytes == 4) {
		s.data[s.pos >> 2] = s.buf;
	}
	else if (n_bytes > 0) {
		int mask = (1 << 8 * n_bytes) - 1 << 8 * (s.pos & 3);
		int old = s.data[s.pos >> 2] & mask;
		int new = s.buf >> 8 * (4 - s.bufpos & 3) & mask;
		#ifdef __XMTC_2_OPENMP__
		// Valgrind/Memcheck doesn't know how bits are updated by the atomic operation below
		s.data[s.pos >> 2] = s.data[s.pos >> 2] & ~mask | new;
		#else
		int incr = new - old;
		psm(incr, s.data[s.pos >> 2]);
		#endif
	}
	return s.bufpos;
}

#define BYTE_STREAM_FLUSH(s) (s.pos = byte_stream_flush(s))

#define BYTE_STREAM_WRITE(s, value) do { \
	SRL(s.buf, s.buf, 8); \
	s.buf |= (value) << 24; \
	s.bufpos++; \
	if ((s.bufpos & 3) == 0) \
		BYTE_STREAM_FLUSH(s); \
} while (0)

#endif
