#include <cstdlib>
#include "list.h"
#include <cassert>
#include <cstdio>

// Builds an empty list: there is no head node, no tail node, and the
// element counter starts at zero.
List::List()
	: first(NULL)
	, last(NULL)
	, count(0)
{
}
// Pre-increment: moves the iterator one node forward and returns it.
//
// The node that was current becomes previousElement and its successor
// becomes currentElement. An iterator that has no current node
// (currentElement == NULL, i.e. it is already past the last node) is left
// untouched, so stepping beyond the end does not dereference a null pointer.
List::Iterator& List::Iterator::operator++()
{
	if (currentElement == NULL)
		return *this;

	previousElement = currentElement;
	currentElement = currentElement->next;
	return *this;
}

// Post-increment: advances this iterator through the prefix form and hands
// back a copy of the position it held before the step. The int parameter
// only selects the postfix overload and carries no value.
List::Iterator List::Iterator::operator++(int)
{
	List::Iterator beforeStep(*this);
	operator++();
	return beforeStep;
}
// Unlinks and frees the node that `position` refers to.
//
// Returns an iterator whose currentElement is the node that followed the
// removed one (NULL if the tail was removed) and whose previousElement is
// the one carried by `position`. When `position` has no current node
// (currentElement == NULL) nothing is removed and `position` is returned
// unchanged.
List::Iterator List::Remove(List::Iterator position)
{
	ListElement* const victim = position.currentElement;
	if (victim == NULL)
		return position;

	ListElement* const predecessor = position.previousElement;
	ListElement* const successor = victim->next;

	// Bypass the victim: either the predecessor's link or, when there is no
	// predecessor, the head pointer has to skip over it.
	if (predecessor != NULL)
		predecessor->next = successor;
	else
		first = successor;

	// Removing the tail makes the predecessor (possibly NULL) the new tail.
	if (successor == NULL)
		last = predecessor;

	delete victim;
	--count;

	Iterator following = position;
	following.currentElement = successor;
	return following;
}
// Releases every node still owned by the list.
//
// The loop is driven by the head pointer itself instead of comparing the
// iterator returned by Remove() with End(): that comparison depended on how
// Iterator's operator!= is defined and on the unspecified order in which
// its two operands are evaluated. Remove(Begin()) unlinks the current head
// whenever one exists, so `first` becomes NULL once all nodes are freed.
List::~List()
{
	while (first != NULL) {
		Remove(Begin());
	}
}

// Dereference: returns a pointer to the value stored in the current node,
// so the caller can read or modify it in place.
//
// Returns NULL when the iterator has no current node (currentElement ==
// NULL) instead of forming an address through a null pointer.
int* List::Iterator::operator*()
{
	if (currentElement == NULL)
		return NULL;

	return &(currentElement->value);
}

// Returns a new iterator positioned one node after this one; this iterator
// itself is not modified.
//
// When there is no current node (currentElement == NULL) there is nothing
// to step over, so a copy of this iterator is returned rather than
// dereferencing a null pointer.
List::Iterator List::Iterator::Next()
{
	if (currentElement == NULL)
		return *this;

	return Iterator(currentElement, currentElement->next);
}

// Returns an iterator to the head of the list: no previous node, and the
// current node is `first` (NULL when the list is empty).
List::Iterator List::Begin()
{
	List::Iterator head(NULL, first);
	return head;
}

// Returns the past-the-end iterator: its previous node is `last` and it has
// no current node. Passing it to Insert() appends at the back.
List::Iterator List::End()
{
	List::Iterator pastTail(last, NULL);
	return pastTail;
}

// Inserts a new node holding `value` immediately before the node that
// `position` refers to; when position.currentElement is NULL the node is
// appended and becomes the new tail.
//
// Returns an iterator whose currentElement is the new node and whose
// previousElement is position.previousElement.
//
// `position` is received by value, so the caller's own iterator is NOT
// updated: after the call its previousElement no longer directly precedes
// its currentElement. Callers that keep iterating should continue from the
// returned iterator. (An earlier revision assigned position.previousElement
// just before returning; that only touched the local copy and was removed.)
List::Iterator List::Insert(List::Iterator position, int value)
{
	// Fully initialise the node before it becomes reachable from the list.
	ListElement* const inserted = new ListElement;
	inserted->value = value;
	// Covers both cases: NULL terminates the list when appending, otherwise
	// the new node links to the node the caller pointed at.
	inserted->next = position.currentElement;

	if (position.previousElement == NULL) {
		// No predecessor: the new node is the head.
		first = inserted;
	}
	else {
		position.previousElement->next = inserted;
	}

	if (position.currentElement == NULL) {
		// Nothing follows the new node: it is the tail.
		last = inserted;
	}

	count++;
	return Iterator(position.previousElement, inserted);
}

// Returns a pointer to the value stored in the node `position` refers to,
// allowing in-place reads and writes.
//
// Returns NULL when `position` has no current node (currentElement ==
// NULL) instead of forming an address through a null pointer.
int* List::ElementAt(List::Iterator position)
{
	if (position.currentElement == NULL)
		return NULL;

	return &(position.currentElement->value);
}


int List::Count()
{
	return count;
}


// Appends the values 1..4 and checks the element count as well as the
// values seen through Begin() and through a copy advanced once with the
// postfix increment.
void test1()
{
	List list;

	int value = 1;
	while (value < 5)
	{
		list.Insert(list.End(), value);
		++value;
	}

	assert(list.Count() == 4);

	// `second` starts as a copy of `first` and is then moved one node on.
	List::Iterator first = list.Begin();
	List::Iterator second(first);
	second++;

	assert(*(list.ElementAt(first)) == 1);
	assert(*(list.ElementAt(second)) == 2);
}

// Appends the values 0..3 and walks the list from Begin() to End() with the
// prefix increment, checking that the values come back in insertion order.
void test2()
{
	List list;

	int next = 0;
	while (next < 4)
	{
		list.Insert(list.End(), next);
		next++;
	}

	int iCheck = 0;
	List::Iterator it = list.Begin();
	while (it != list.End())
	{
		int* valuePtr = *it;

		assert(iCheck == *valuePtr);

		iCheck++;
		++it;
	}
}

// Appends the values 0..19, then repeatedly removes the front element until
// the list is empty, checking each value just before it is removed.
void test3()
{
	List list;

	for (int value = 0; value != 20; ++value)
	{
		list.Insert(list.End(), value);
	}

	List::Iterator it = list.Begin();

	// Remove() returns the iterator to the following node, which becomes
	// the new front for the next round.
	for (int iCheck = 0; list.Count() > 0; iCheck++)
	{
		assert(iCheck == *(*it));
		it = list.Remove(it);
	}
}

// Runs the self-tests in order; a failing assert aborts the program, so
// reaching the end means every check passed.
int main()
{
	void (*const checks[])() = { test1, test2, test3 };
	const int checkCount = sizeof(checks) / sizeof(checks[0]);

	for (int index = 0; index < checkCount; ++index)
	{
		checks[index]();
	}

	return 0;
}
