#include <pico/stdlib.h>
#include <pico/stdio_usb.h>
#include <pico/multicore.h>
#include <pico/util/queue.h>

#include <hardware/clocks.h>
#include <hardware/dma.h>
#include <hardware/gpio.h>
#include <hardware/pll.h>
#include <hardware/vreg.h>
#include <hardware/sync.h>
#include <hardware/pio.h>
#include <hardware/pwm.h>
#include <hardware/interp.h>

#include <hardware/regs/clocks.h>
#include <hardware/structs/bus_ctrl.h>

#include <math.h>
#include <stdio.h>
#include <limits.h>
#include <stdlib.h>

#define VREG_VOLTAGE VREG_VOLTAGE_1_20
#define CLK_SYS_HZ (300 * MHZ)

#define LO_PIN 9
#define RX_PIN 8
#define FB_PIN 5
#define PSU_PIN 23

#define PIO pio1
#define LO_SM 0
#define FB_SM 1
#define RX_SM 2
#define AD_SM 3

#define IQ_SAMPLES 32
#define IQ_BLOCK_LEN (2 * IQ_SAMPLES)
#define IQ_QUEUE_LEN 8

/*
 * log2 of the size of lo_cos in bytes. The same value is used as the DMA read
 * ring size in rf_rx_start() and as the alignment of the buffer.
 */
#define LO_BITS_DEPTH 15
/* Number of 32-bit words in lo_cos. */
#define LO_WORDS (1 << (LO_BITS_DEPTH - 2))
/* Local-oscillator bitstream, filled by lo_generate() and streamed to LO_SM. */
static uint32_t lo_cos[LO_WORDS] __attribute__((__aligned__(1 << LO_BITS_DEPTH)));

/*
 * Ratio between the rate at which AD_SM is asked to push its counters
 * (sample_rate * DECIMATE, see the pacing timer) and the output sample rate.
 */
#define DECIMATE 4
/* Words of rx_cos consumed per output block: two words per counter push. */
#define RX_STRIDE (2 * IQ_SAMPLES * DECIMATE)
/* log2 of the size of rx_cos in bytes; also the DMA write ring size in do_rx(). */
#define RX_BITS_DEPTH 13
/* Number of 32-bit words in rx_cos. */
#define RX_WORDS (1 << (RX_BITS_DEPTH - 2))

/* The ring must hold at least four blocks worth of counter words. */
static_assert(RX_STRIDE * 4 <= RX_WORDS, "RX_STRIDE * 4 <= RX_WORDS");

/* Ring buffer receiving the X/Y counter pairs pushed by AD_SM. */
static uint32_t rx_cos[RX_WORDS] __attribute__((__aligned__(1 << RX_BITS_DEPTH)));

/* Power-on defaults for the sample rate (Hz), tuned frequency (Hz) and gain. */
#define INIT_SAMPLE_RATE 100000
#define INIT_FREQ 94600000
#define INIT_GAIN 127

/*
 * Gain table selected by index through command 0x0d. run_command() converts an
 * entry g into the linear factor 10^(g / 200), so the entries are presumably
 * tenths of a dB (TODO confirm against the client's expectations).
 */
#define NUM_GAINS 29
static int gains[NUM_GAINS] = { 0,   9,	  14,  27,  37,	 77,  87,  125, 144, 157,
				166, 197, 207, 229, 254, 280, 297, 328, 338, 364,
				372, 386, 402, 421, 434, 439, 445, 480, 496 };
/* Current output sample rate in Hz; changed by command 0x02. */
static int sample_rate = INIT_SAMPLE_RATE;
/* Current linear gain; next_sample() multiplies by it and divides by 128. */
static int gain = INIT_GAIN;

/*
 * Initial values for the 32-bit phase accumulator in lo_generate().
 * COS_PHASE is three quarters of the full 2^32 range. SIN_PHASE is not
 * referenced elsewhere in this file.
 */
#define SIN_PHASE (0u)
#define COS_PHASE (3u << 30)

/*
 * DMA channel and timer handles. Each holds -1 while unclaimed; they are
 * claimed in rf_rx_start() / do_rx() and reset to -1 again on shutdown.
 */

/* Ping-pong pair copying words from the RX_SM RX FIFO to the AD_SM TX FIFO. */
static int dma_ch_rx1 = -1;
static int dma_ch_rx2 = -1;

/* Ping-pong pair streaming lo_cos into the LO_SM TX FIFO. */
static int dma_ch_mix1 = -1;
static int dma_ch_mix2 = -1;

/* Channel that periodically writes samp_insn into the AD_SM instruction register. */
static int dma_ch_samp_cos = -1;

/* DMA pacing timer that drives dma_ch_samp_cos. */
static int dma_t_samp = -1;

/* Channel draining the AD_SM RX FIFO into the rx_cos ring. */
static int dma_ch_in_cos = -1;

/*
 * Queue of pointers to finished blocks inside iq_queue_buffer; filled by
 * rf_rx() and drained by do_rx().
 */
static queue_t iq_queue;
static uint8_t iq_queue_buffer[IQ_QUEUE_LEN][IQ_BLOCK_LEN];
/* Index of the buffer rf_rx() fills next. */
static size_t iq_queue_pos = 0;

/*
 * Rewrite the CHAIN_TO field of channel `ch` so that it names the channel
 * itself. The RP2040 datasheet documents this as the way to disable chaining.
 * The AL1 control alias is used for both the read and the write.
 */
static void dma_channel_clear_chain_to(int ch)
{
	const uint32_t old_ctrl = dma_hw->ch[ch].al1_ctrl;

	dma_hw->ch[ch].al1_ctrl = (old_ctrl & ~DMA_CH0_CTRL_TRIG_CHAIN_TO_BITS) |
				  (ch << DMA_CH0_CTRL_TRIG_CHAIN_TO_LSB);
}

static void init_lo()
{
	gpio_disable_pulls(LO_PIN);
	pio_gpio_init(PIO, LO_PIN);

	gpio_set_drive_strength(LO_PIN, GPIO_DRIVE_STRENGTH_12MA);
	gpio_set_slew_rate(LO_PIN, GPIO_SLEW_RATE_FAST);

	const uint16_t insn[] = {
		pio_encode_out(pio_pindirs, 1),
	};

	pio_program_t prog = {
		.instructions = insn,
		.length = sizeof(insn) / sizeof(*insn),
		.origin = -1,
	};

	pio_sm_restart(PIO, LO_SM);
	pio_sm_clear_fifos(PIO, LO_SM);

	if (pio_can_add_program(PIO, &prog))
		prog.origin = pio_add_program(PIO, &prog);

	pio_sm_config pc = pio_get_default_sm_config();
	sm_config_set_out_pins(&pc, LO_PIN, 1);
	sm_config_set_set_pins(&pc, LO_PIN, 1);
	sm_config_set_wrap(&pc, prog.origin, prog.origin + prog.length - 1);
	sm_config_set_clkdiv_int_frac(&pc, 1, 0);
	sm_config_set_fifo_join(&pc, PIO_FIFO_JOIN_TX);
	sm_config_set_out_shift(&pc, false, true, 32);
	pio_sm_init(PIO, LO_SM, prog.origin, &pc);

	pio_sm_set_consecutive_pindirs(PIO, LO_SM, LO_PIN, 1, GPIO_OUT);
	pio_sm_exec_wait_blocking(PIO, LO_SM, pio_encode_set(pio_pins, 0));
}

/*
 * Prepare FB_SM: a one-instruction program that copies the inverted state of
 * RX_PIN to FB_PIN on every clock, with a 1-bit side-set applied to the pin
 * direction of FB_PIN.
 *
 * The program is loaded into PIO instruction memory only on the first call;
 * the offset is remembered so that restarting the receiver does not consume
 * another slot each time. If the program cannot be loaded, the state
 * machine is left unconfigured instead of being set up with an invalid
 * offset.
 */
static void init_fb()
{
	/* Offset of the loaded program, or -1 while it has not been loaded. */
	static int loaded_at = -1;

	gpio_disable_pulls(FB_PIN);
	pio_gpio_init(PIO, FB_PIN);

	// NOTE: Not sure if this is ideal.
	hw_set_bits(&PIO->input_sync_bypass, 1u << RX_PIN);

	gpio_set_input_hysteresis_enabled(RX_PIN, false);
	gpio_set_drive_strength(FB_PIN, GPIO_DRIVE_STRENGTH_2MA);
	gpio_set_slew_rate(FB_PIN, GPIO_SLEW_RATE_SLOW);

	const uint16_t insn[] = {
		pio_encode_mov_not(pio_pins, pio_pins) | pio_encode_sideset(1, 1) |
			pio_encode_delay(0),
		//pio_encode_nop() | pio_encode_sideset(1, 0) | pio_encode_delay(0),
	};

	pio_program_t prog = {
		.instructions = insn,
		.length = sizeof(insn) / sizeof(*insn),
		.origin = -1,
	};

	pio_sm_restart(PIO, FB_SM);
	pio_sm_clear_fifos(PIO, FB_SM);

	if (loaded_at < 0 && pio_can_add_program(PIO, &prog))
		loaded_at = pio_add_program(PIO, &prog);

	if (loaded_at < 0)
		return;

	prog.origin = loaded_at;

	pio_sm_config pc = pio_get_default_sm_config();
	/* One mandatory side-set bit that targets pin directions. */
	sm_config_set_sideset(&pc, 1, false, true);
	sm_config_set_in_pins(&pc, RX_PIN);
	sm_config_set_out_pins(&pc, FB_PIN, 1);
	sm_config_set_set_pins(&pc, FB_PIN, 1);
	sm_config_set_sideset_pins(&pc, FB_PIN);
	sm_config_set_wrap(&pc, prog.origin, prog.origin + prog.length - 1);
	sm_config_set_clkdiv_int_frac(&pc, 1, 0);
	pio_sm_init(PIO, FB_SM, prog.origin, &pc);

	pio_sm_set_consecutive_pindirs(PIO, FB_SM, FB_PIN, 1, GPIO_OUT);
}

/*
 * Prepare RX_SM: a one-instruction program that shifts the state of RX_PIN
 * into the ISR once per clock and autopushes every 32 bits.
 *
 * The program is loaded into PIO instruction memory only on the first call;
 * the offset is remembered so that restarting the receiver does not consume
 * another slot each time. If the program cannot be loaded, the state
 * machine is left unconfigured instead of being set up with an invalid
 * offset.
 */
static void init_rx()
{
	/* Offset of the loaded program, or -1 while it has not been loaded. */
	static int loaded_at = -1;

	gpio_disable_pulls(RX_PIN);
	pio_gpio_init(PIO, RX_PIN);

	const uint16_t insn[] = {
		pio_encode_in(pio_pins, 1) | pio_encode_delay(0),
	};

	pio_program_t prog = {
		.instructions = insn,
		.length = sizeof(insn) / sizeof(*insn),
		.origin = -1,
	};

	pio_sm_restart(PIO, RX_SM);
	pio_sm_clear_fifos(PIO, RX_SM);

	if (loaded_at < 0 && pio_can_add_program(PIO, &prog))
		loaded_at = pio_add_program(PIO, &prog);

	if (loaded_at < 0)
		return;

	prog.origin = loaded_at;

	pio_sm_config pc = pio_get_default_sm_config();
	sm_config_set_in_pins(&pc, RX_PIN);
	sm_config_set_wrap(&pc, prog.origin, prog.origin + prog.length - 1);
	sm_config_set_clkdiv_int_frac(&pc, 1, 0);
	sm_config_set_fifo_join(&pc, PIO_FIFO_JOIN_RX);
	/* Shift left with autopush after 32 bits. */
	sm_config_set_in_shift(&pc, false, true, 32);
	pio_sm_init(PIO, RX_SM, prog.origin, &pc);

	pio_sm_set_consecutive_pindirs(PIO, RX_SM, RX_PIN, 1, GPIO_IN);
}

/*
 * Prepare AD_SM, the state machine that counts bit pairs of the received
 * stream in its X and Y scratch registers.
 *
 * The program is placed at the fixed offset 0 because `out pc, 2` jumps to
 * absolute addresses 0..3 chosen by the next two bits of the stream. The
 * wrapped region covers the first five instructions; the tail that pushes
 * and clears both counters is entered only by a jump injected from outside.
 */
static void init_ad()
{
	enum {
		AD_ENTRY = 1, /* first instruction executed, relative to origin */
		AD_LOOP_LEN = 5, /* instructions inside the wrapped region */
	};

	const uint16_t code[] = {
		/* Wrapped region: decode two bits at a time. */
		pio_encode_jmp_y_dec(1),
		pio_encode_out(pio_pc, 2),
		pio_encode_out(pio_pc, 2),
		pio_encode_jmp_x_dec(2),

		/* Avoid Y-- on wrap. */
		pio_encode_out(pio_pc, 2),

		/*
		 * Wrap happens before this point.
		 * Push both counters, clear them and resume decoding.
		 */
		pio_encode_in(pio_x, 32),
		pio_encode_in(pio_y, 32),
		pio_encode_set(pio_x, 0),
		pio_encode_set(pio_y, 0),
		pio_encode_out(pio_pc, 2),
	};

	pio_program_t prog = {
		.instructions = code,
		.length = sizeof(code) / sizeof(code[0]),
		.origin = 0,
	};

	pio_sm_restart(PIO, AD_SM);
	pio_sm_clear_fifos(PIO, AD_SM);

	/* Skipped when the slots are taken, e.g. by an earlier load. */
	if (pio_can_add_program(PIO, &prog)) {
		pio_add_program(PIO, &prog);
	}

	const int first = prog.origin;

	pio_sm_config pc = pio_get_default_sm_config();
	sm_config_set_wrap(&pc, first, first + AD_LOOP_LEN - 1);
	sm_config_set_clkdiv_int_frac(&pc, 1, 0);
	sm_config_set_in_shift(&pc, false, true, 32);
	sm_config_set_out_shift(&pc, false, true, 32);
	pio_sm_init(PIO, AD_SM, first + AD_ENTRY, &pc);
}

/* Phase-accumulator increment per Hz: the 32-bit phase range divided by the clock. */
#define STEP_BASE ((UINT_MAX + 1.0) / CLK_SYS_HZ)
/* Phase increment per output bit, as last computed by lo_generate(). */
static uint32_t freq_step = 1;

/*
 * Fill `buf` (LO_WORDS words) with a 1-bit square wave of frequency `freq`.
 *
 * A 32-bit phase accumulator starting at `phase` advances by freq_step for
 * every output bit, and the top bit of the accumulator is the emitted bit.
 * Bits are packed most significant first: sample j of a word lands in bit
 * 31 - j. This matches LO_SM, which shifts its output register to the left.
 *
 * The shift has to happen before the new bit is merged in. Doing it the
 * other way round pushes the first sample of every word out of the register
 * and leaves bit 0 permanently clear.
 *
 * freq is in Hz and is expected to lie in [0, CLK_SYS_HZ); outside that
 * range the conversion of the step to uint32_t is not defined.
 * Updates the file-level freq_step as a side effect.
 */
static void lo_generate(uint32_t *buf, double freq, uint32_t phase)
{
	freq_step = STEP_BASE * freq;

	for (size_t i = 0; i < LO_WORDS; i++) {
		uint32_t bits = 0;

		for (int j = 0; j < 32; j++) {
			bits = (bits << 1) | (phase >> 31);
			phase += freq_step;
		}

		buf[i] = bits;
	}
}

/*
 * Regenerate the local-oscillator bitstream in lo_cos for `req_freq` Hz.
 *
 * With `align` set, the frequency is first rounded to the nearest multiple
 * of CLK_SYS_HZ divided by the number of bits in lo_cos (8 << LO_BITS_DEPTH).
 */
static void rx_lo_init(double req_freq, bool align)
{
	const double step_hz = (double)CLK_SYS_HZ / (8 << LO_BITS_DEPTH);
	const double tuned = align ? round(req_freq / step_hz) * step_hz : req_freq;

	lo_generate(lo_cos, tuned, COS_PHASE);
}

/*
 * Instruction word that dma_ch_samp_cos writes into the AD_SM instruction
 * register on every pacing-timer tick. Presumably decodes as an unconditional
 * jump to address 5, the counter-push tail of the init_ad() program
 * (NOTE(review): confirm against the PIO instruction encoding).
 */
static const uint32_t samp_insn = 5;

/*
 * Configure (without starting) one channel of the pair that copies words from
 * the RX_SM RX FIFO into the AD_SM TX FIFO. `partner` is triggered by chaining
 * once this channel's transfer count runs out.
 */
static void rx_copy_dma_setup(int ch, int partner)
{
	dma_channel_config cfg = dma_channel_get_default_config(ch);

	channel_config_set_transfer_data_size(&cfg, DMA_SIZE_32);
	channel_config_set_read_increment(&cfg, false);
	channel_config_set_write_increment(&cfg, false);
	/* Paced by the RX side of RX_SM. */
	channel_config_set_dreq(&cfg, pio_get_dreq(PIO, RX_SM, false));
	channel_config_set_chain_to(&cfg, partner);
	dma_channel_configure(ch, &cfg, &PIO->txf[AD_SM], &PIO->rxf[RX_SM], UINT_MAX, false);
}

/*
 * Configure (without starting) one channel of the pair that streams lo_cos
 * into the LO_SM TX FIFO. The read address wraps inside a ring of
 * 2^LO_BITS_DEPTH bytes. `partner` is triggered by chaining once this
 * channel's transfer count runs out.
 */
static void lo_feed_dma_setup(int ch, int partner)
{
	dma_channel_config cfg = dma_channel_get_default_config(ch);

	channel_config_set_transfer_data_size(&cfg, DMA_SIZE_32);
	channel_config_set_read_increment(&cfg, true);
	channel_config_set_write_increment(&cfg, false);
	/* Ring applies to the read address. */
	channel_config_set_ring(&cfg, false, LO_BITS_DEPTH);
	/* Paced by the TX side of LO_SM. */
	channel_config_set_dreq(&cfg, pio_get_dreq(PIO, LO_SM, true));
	channel_config_set_chain_to(&cfg, partner);
	dma_channel_configure(ch, &cfg, &PIO->txf[LO_SM], lo_cos, UINT_MAX, false);
}

/*
 * Claim the DMA channels and pacing timer, configure them, initialise the
 * four PIO state machines and set everything running.
 */
static void rf_rx_start()
{
	dma_ch_rx1 = dma_claim_unused_channel(true);
	dma_ch_rx2 = dma_claim_unused_channel(true);

	dma_ch_mix1 = dma_claim_unused_channel(true);
	dma_ch_mix2 = dma_claim_unused_channel(true);

	dma_ch_samp_cos = dma_claim_unused_channel(true);

	dma_t_samp = dma_claim_unused_timer(true);

	/* Copy PDM bitstream into decimator. */
	rx_copy_dma_setup(dma_ch_rx1, dma_ch_rx2);
	rx_copy_dma_setup(dma_ch_rx2, dma_ch_rx1);

	/* Drive the LO capacitor. */
	lo_feed_dma_setup(dma_ch_mix1, dma_ch_mix2);
	lo_feed_dma_setup(dma_ch_mix2, dma_ch_mix1);

	/* Pacing timer for the sampling script trigger channel. */
	dma_timer_set_fraction(dma_t_samp, 1, CLK_SYS_HZ / (sample_rate * DECIMATE));

	/* Trigger accumulator values push. */
	dma_channel_config trig = dma_channel_get_default_config(dma_ch_samp_cos);

	channel_config_set_transfer_data_size(&trig, DMA_SIZE_32);
	channel_config_set_read_increment(&trig, false);
	channel_config_set_write_increment(&trig, false);
	channel_config_set_high_priority(&trig, true);
	channel_config_set_dreq(&trig, dma_get_timer_dreq(dma_t_samp));
	dma_channel_configure(dma_ch_samp_cos, &trig, &PIO->sm[AD_SM].instr, &samp_insn,
			      UINT_MAX, false);

	init_lo();
	init_fb();
	init_rx();
	init_ad();

	/* Only the first channel of each pair is started; the partner follows by chaining. */
	dma_channel_start(dma_ch_rx1);
	dma_channel_start(dma_ch_mix1);
	dma_channel_start(dma_ch_samp_cos);

	/* Enable all four state machines at once. */
	pio_set_sm_mask_enabled(PIO, 0x0f, true);
}

/*
 * Undo rf_rx_start(): stop the state machines, then for all five channels
 * in turn break the chains, abort, clean up and release them, and finally
 * release the pacing timer. All handles end up as -1.
 */
static void rf_rx_stop(void)
{
	const int chans[] = {
		dma_ch_rx1, dma_ch_rx2, dma_ch_mix1, dma_ch_mix2, dma_ch_samp_cos,
	};
	const size_t nchans = sizeof(chans) / sizeof(chans[0]);

	pio_set_sm_mask_enabled(PIO, 0x0f, false);

	sleep_us(10);

	/* Each stage is applied to every channel before the next stage begins. */
	for (size_t i = 0; i < nchans; i++) {
		dma_channel_clear_chain_to(chans[i]);
	}

	for (size_t i = 0; i < nchans; i++) {
		dma_channel_abort(chans[i]);
	}

	for (size_t i = 0; i < nchans; i++) {
		dma_channel_cleanup(chans[i]);
	}

	for (size_t i = 0; i < nchans; i++) {
		dma_channel_unclaim(chans[i]);
	}

	dma_timer_unclaim(dma_t_samp);

	dma_ch_rx1 = dma_ch_rx2 = -1;
	dma_ch_mix1 = dma_ch_mix2 = -1;
	dma_ch_samp_cos = -1;

	dma_t_samp = -1;
}

/* One complex output sample as returned by next_sample(). */
struct IQ {
	int I; /* in-phase component */
	int Q; /* quadrature component */
};

/*
 * Produce one I/Q sample from eight consecutive counter words at `buf`.
 *
 * The eight words are reduced to four differences, which are appended to a
 * history of the previous eleven differences. The resulting fifteen taps
 * are weighted with a symmetric kernel; taps alternate between the Q and
 * the I sum, starting with Q at the oldest tap. Each weight is the sum of up
 * to four neighbouring coefficients of the sequence
 * C0 C1 C2 C3 C4 C5 C5 C4 C3 C2 C1 C0.
 *
 * Both sums are scaled by gain / 128. The history is kept in a function
 * static, so the function carries state between calls.
 */
inline static struct IQ next_sample(const uint32_t *buf)
{
	/* Previous differences; hist[0] is the most recent one. */
	static int hist[11];

	enum { C0 = 4, C1 = 2, C2 = -8, C3 = -9, C4 = 19, C5 = 55 };
	enum {
		W1 = C0,
		W2 = C0 + C1,
		W3 = C0 + C1 + C2,
		W4 = C0 + C1 + C2 + C3,
		W5 = C1 + C2 + C3 + C4,
		W6 = C2 + C3 + C4 + C5,
		W7 = C3 + C4 + C5 + C5,
		W8 = C4 + C5 + C5 + C4,
	};

	/* Newest four differences; note the swapped operands of the last two. */
	const int d3 = buf[0] - buf[1];
	const int d2 = buf[2] - buf[3];
	const int d1 = buf[5] - buf[4];
	const int d0 = buf[7] - buf[6];

	int in_phase = W2 * hist[9] + W4 * hist[7] + W6 * hist[5] + W8 * hist[3] +
		       W6 * hist[1] + W4 * d3 + W2 * d1;

	int quadrature = W1 * hist[10] + W3 * hist[8] + W5 * hist[6] + W7 * hist[4] +
			 W7 * hist[2] + W5 * hist[0] + W3 * d2 + W1 * d0;

	in_phase = (in_phase * gain) / 128;
	quadrature = (quadrature * gain) / 128;

	/* Age the history by four slots, oldest first so nothing is overwritten early. */
	for (int k = 10; k >= 4; k--) {
		hist[k] = hist[k - 4];
	}

	hist[3] = d3;
	hist[2] = d2;
	hist[1] = d1;
	hist[0] = d0;

	return (struct IQ){ in_phase, quadrature };
}

static void rf_rx(void)
{
	const uint32_t base = (uint32_t)rx_cos;
	int pos = 0;

	while (true) {
		if (multicore_fifo_rvalid()) {
			multicore_fifo_pop_blocking();
			multicore_fifo_push_blocking(0);
			return;
		}

		int head = (dma_hw->ch[dma_ch_in_cos].write_addr - base) / 4;
		int delta = (head < pos ? head + RX_WORDS : head) - pos;

		sleep_us(10);

		while (delta < RX_STRIDE) {
			sleep_us(1);
			head = (dma_hw->ch[dma_ch_in_cos].write_addr - base) / 4;
			delta = (head < pos ? head + RX_WORDS : head) - pos;
		}

		const uint32_t *cos_ptr = rx_cos + pos;

		pos = (pos + RX_STRIDE) & (RX_WORDS - 1);

		uint8_t *block = iq_queue_buffer[iq_queue_pos];
		uint8_t *blockptr = block;

		/*
		 * Since every 2 samples add to either +1 or -1,
		 * the maximum amplitude in one direction is 1/2.
		 *
		 * We are allowing the counters to only go as high
		 * as sampling rate.
		 */
		int64_t max_amplitude = CLK_SYS_HZ / 2 / sample_rate;

		for (int i = 0; i < IQ_SAMPLES; i++) {
			struct IQ IQ = next_sample(cos_ptr);
			int64_t I = IQ.I;
			int64_t Q = IQ.Q;
			cos_ptr += 2 * DECIMATE;

			I -= (max_amplitude * 181) / 256;
			I /= max_amplitude;

			if (I > 127)
				I = 127;
			else if (I < -128)
				I = -128;

			*blockptr++ = (uint8_t)I + 128;

			Q -= (max_amplitude * 181) / 256;
			Q /= max_amplitude;

			if (Q > 127)
				Q = 127;
			else if (Q < -128)
				Q = -128;

			*blockptr++ = (uint8_t)Q + 128;
		}

		if (queue_try_add(&iq_queue, &block)) {
			iq_queue_pos = (iq_queue_pos + 1) & (IQ_QUEUE_LEN - 1);
		}
	}
}

static void run_command(uint8_t cmd, uint32_t arg)
{
	if (0x01 == cmd) {
		/* Tune to a new center frequency */
		rx_lo_init(arg + sample_rate, true);
	} else if (0x02 == cmd) {
		/* Set the rate at which IQ sample pairs are sent */
		sample_rate = arg;
		dma_timer_set_fraction(dma_t_samp, 1, CLK_SYS_HZ / (sample_rate * DECIMATE));
		rx_lo_init(arg + sample_rate, true);
	} else if (0x04 == cmd) {
		/* Set the tuner gain level */
		gain = INIT_GAIN * powf(10.0f, arg / 200.0f);
	} else if (0x0d == cmd) {
		/* Set tuner gain by the tuner's gain index */

		if (arg >= NUM_GAINS)
			arg = NUM_GAINS - 1;

		gain = INIT_GAIN * powf(10.0f, gains[arg] / 200.0f);
	}
}

/*
 * Poll stdin without blocking and assemble 5-byte commands: one command
 * byte followed by a 32-bit big-endian argument.
 *
 * Returns
 *   -1          when no complete command is available yet,
 *    0          when a zero byte arrives at the start of a command
 *               (do_rx() treats this as the request to stop),
 *    the command byte after a complete command has been passed to
 *    run_command().
 *
 * Partially received commands are kept in static storage between calls.
 */
static int check_command(void)
{
	static uint8_t buf[5];
	static int pos = 0;

	int c;

	while ((c = getchar_timeout_us(0)) >= 0) {
		if (0 == pos && 0 == c)
			return 0;

		buf[pos++] = c;

		if (5 == pos) {
			/*
			 * Widen before shifting: a uint8_t promotes to int, and
			 * shifting a value >= 0x80 left by 24 would overflow it.
			 */
			uint32_t arg = ((uint32_t)buf[1] << 24) | ((uint32_t)buf[2] << 16) |
				       ((uint32_t)buf[3] << 8) | (uint32_t)buf[4];
			run_command(buf[0], arg);
			pos = 0;
			return buf[0];
		}
	}

	return -1;
}

/*
 * Run one receive session: start the radio front end, drain AD_SM into the
 * rx_cos ring by DMA, launch rf_rx() on core 1 and forward finished blocks
 * to stdout until check_command() reports 0. Everything is then torn down.
 */
static void do_rx()
{
	rf_rx_start();
	sleep_us(100);

	dma_ch_in_cos = dma_claim_unused_channel(true);

	dma_channel_config cfg = dma_channel_get_default_config(dma_ch_in_cos);

	channel_config_set_transfer_data_size(&cfg, DMA_SIZE_32);
	channel_config_set_read_increment(&cfg, false);
	channel_config_set_write_increment(&cfg, true);
	/* Wrap the write address inside the 2^RX_BITS_DEPTH byte ring. */
	channel_config_set_ring(&cfg, true, RX_BITS_DEPTH);
	channel_config_set_dreq(&cfg, pio_get_dreq(PIO, AD_SM, false));
	dma_channel_configure(dma_ch_in_cos, &cfg, rx_cos, &PIO->rxf[AD_SM], UINT_MAX, true);

	multicore_launch_core1(rf_rx);

	const uint8_t *block;

	/* Flush the queue */
	while (queue_try_remove(&iq_queue, &block)) {
	}

	bool stop = false;

	while (!stop) {
		int cmd;

		/* Handle every pending command before touching the queue. */
		while ((cmd = check_command()) >= 0) {
			if (0 == cmd) {
				stop = true;
				break;
			}
		}

		if (stop)
			break;

		if (!queue_try_remove(&iq_queue, &block)) {
			sleep_us(25);
			continue;
		}

		fwrite(block, IQ_BLOCK_LEN, 1, stdout);
		fflush(stdout);
	}

	/* Ask core 1 to stop and wait for its acknowledgement. */
	multicore_fifo_push_blocking(0);
	multicore_fifo_pop_blocking();
	sleep_us(10);
	multicore_reset_core1();

	rf_rx_stop();

	dma_channel_clear_chain_to(dma_ch_in_cos);
	dma_channel_abort(dma_ch_in_cos);
	dma_channel_cleanup(dma_ch_in_cos);
	dma_channel_unclaim(dma_ch_in_cos);
	dma_ch_in_cos = -1;
}

/*
 * Firmware entry point: raise the core voltage and clocks, bring up USB
 * stdio, generate the initial local-oscillator table and then wait for the
 * host. The first complete command starts a receive session, which is
 * preceded by a 12-byte header.
 */
int main()
{
	vreg_set_voltage(VREG_VOLTAGE);
	set_sys_clock_khz(CLK_SYS_HZ / KHZ, true);
	clock_configure(clk_peri, 0, CLOCKS_CLK_PERI_CTRL_AUXSRC_VALUE_CLKSRC_PLL_SYS, CLK_SYS_HZ,
			CLK_SYS_HZ);

	/* Enable PSU PWM mode. */
	gpio_init(PSU_PIN);
	gpio_set_dir(PSU_PIN, GPIO_OUT);
	gpio_put(PSU_PIN, 1);

	/* Give DMA reads and writes priority on the bus. */
	bus_ctrl_hw->priority |= BUSCTRL_BUS_PRIORITY_DMA_W_BITS | BUSCTRL_BUS_PRIORITY_DMA_R_BITS;

	stdio_usb_init();
	setvbuf(stdout, NULL, _IONBF, 0);

	/* The queue carries pointers to blocks, not the blocks themselves. */
	queue_init(&iq_queue, sizeof(uint8_t *), IQ_QUEUE_LEN);

	/*
	 * Use the same offset as run_command(), which generates the LO at the
	 * requested frequency plus the sample rate.
	 */
	rx_lo_init(INIT_FREQ + INIT_SAMPLE_RATE, true);

	while (true) {
		if (check_command() > 0) {
			/*
			 * Three big-endian words: the ASCII magic "RTL0", the
			 * value 5 and the number of gain steps. Presumably the
			 * rtl_tcp dongle-info header (magic, tuner type, gain
			 * count) - confirm against the client.
			 */
			static const uint32_t header[3] = { __builtin_bswap32(0x52544c30),
							    __builtin_bswap32(5),
							    __builtin_bswap32(NUM_GAINS) };
			fwrite(header, sizeof header, 1, stdout);
			fflush(stdout);

			do_rx();
		}

		sleep_ms(10);
	}
}