package anyreader

import (
	"git.milar.in/milarin/ds"
	"git.milar.in/milarin/slices"
)

// SafeReader reads values of type T from a source function and remembers
// every value it has handed out, so that reads can be undone with Unread and
// positions can be saved and restored with Push and Pop.
//
// The source function has no error result, and neither do the reader's
// methods.
type SafeReader[T any] struct {
	// buf holds every value returned by Read that has not been unread since,
	// most recently read on top.
	buf ds.Stack[T]

	// indices holds the read positions saved by Push.
	indices ds.Stack[uint64]

	// index is the current read position: it is incremented by Read and
	// decremented by every Unread that actually takes a value back.
	index uint64

	// src produces the next fresh value once pending is exhausted.
	src func() T

	// pending holds unread values waiting to be delivered again. The last
	// element is the next one Read returns. A nil slice is a valid empty
	// state, so readers built without this field keep working.
	pending []T
}

// NewSafeReaderFromSlice returns a SafeReader whose source is
// sliceToSafeFunc(s).
//
// NOTE(review): sliceToSafeFunc is defined elsewhere; what it yields once s is
// exhausted is not visible here.
func NewSafeReaderFromSlice[T any](s []T) *SafeReader[T] {
	return NewSafeReaderFromFunc(sliceToSafeFunc(s))
}

// NewSafeReaderFromFunc returns a SafeReader positioned at index 0 that calls
// src each time a fresh value is needed.
func NewSafeReaderFromFunc[T any](src func() T) *SafeReader[T] {
	return &SafeReader[T]{
		src:     src,
		buf:     ds.NewArrayStack[T](),
		index:   0,
		indices: ds.NewArrayStack[uint64](),
	}
}

// Read returns the next value and advances the read position by one.
//
// Values handed back via Unread are delivered first, most recently unread
// first; only when none are pending is the source function called.
func (r *SafeReader[T]) Read() T {
	var v T
	if last := len(r.pending) - 1; last >= 0 {
		v = r.pending[last]

		// Clear the slot so the backing array does not keep the value alive.
		var zero T
		r.pending[last] = zero
		r.pending = r.pending[:last]
	} else {
		v = r.src()
	}

	r.buf.Push(v)
	r.index++
	return v
}

// Unread takes back the most recently read value so that the next Read
// returns it again, and moves the read position back by one. It does nothing
// when there is no read value left to take back.
//
// The value is queued on a pushback stack instead of wrapping the source
// function in another closure: wrapping added one permanent layer per Unread
// (and therefore per Peek), making every later Read slower and deeper.
func (r *SafeReader[T]) Unread() {
	if r.buf.Empty() {
		return
	}

	r.pending = append(r.pending, r.buf.Pop())
	r.index--
}

// UnreadN calls Unread n times. A zero or negative n does nothing.
func (r *SafeReader[T]) UnreadN(n int) {
	for remaining := n; remaining > 0; remaining-- {
		r.Unread()
	}
}

// Peek returns the value the next Read will return without moving the read
// position: it reads one value and immediately unreads it.
func (r *SafeReader[T]) Peek() T {
	next := r.Read()
	r.Unread() // hand the value straight back
	return next
}

// ReadWhile reads values for as long as findFirstTrue(value, f) reports true
// and returns the values read that way, in order.
//
// The first value for which findFirstTrue reports false ends the loop. That
// value is not part of the result, but it has been read and is not unread
// here.
//
// NOTE(review): findFirstTrue is defined elsewhere; presumably it reports
// whether at least one function in f returns true for the value — confirm.
func (r *SafeReader[T]) ReadWhile(f ...func(T) bool) []T {
	matched := make([]T, 0, 10)
	for {
		next := r.Read()
		if !findFirstTrue(next, f) {
			return matched
		}
		matched = append(matched, next)
	}
}

// ReadUntil is the inverse of ReadWhile: it reads values for as long as
// findFirstTrue(value, f) reports false and returns them in order.
//
// It delegates to ReadWhile with the negated condition, so the value that
// ends the run is read but neither returned nor unread.
func (r *SafeReader[T]) ReadUntil(f ...func(T) bool) []T {
	stillSearching := func(candidate T) bool {
		return !findFirstTrue(candidate, f)
	}
	return r.ReadWhile(stillSearching)
}

// SkipUntil runs ReadUntil with f, discards the values it returns, and then
// unreads once, handing back the value that ended the run so that it is the
// next one read.
func (r *SafeReader[T]) SkipUntil(f ...func(T) bool) {
	_ = r.ReadUntil(f...)
	r.Unread() // hand back the value that stopped ReadUntil
}

// SkipWhile runs ReadWhile with f, discards the values it returns, and then
// unreads once, handing back the value that ended the run so that it is the
// next one read.
func (r *SafeReader[T]) SkipWhile(f ...func(T) bool) {
	_ = r.ReadWhile(f...)
	r.Unread() // hand back the value that stopped ReadWhile
}

// Expect reads one value and returns the result of findFirstTrue for that
// value and f. The value stays read whatever the result is.
func (r *SafeReader[T]) Expect(f ...func(T) bool) bool {
	next := r.Read()
	return findFirstTrue(next, f)
}

// Push saves the current read position on the index stack so that a later
// Pop can return to it.
func (r *SafeReader[T]) Push() {
	mark := r.index
	r.indices.Push(mark)
}

// Pop removes the most recently pushed position from the index stack, moves
// the reader back to it, and returns the values lying between that position
// and the position the reader was at when Pop was called.
//
//   - If the reader is ahead of the saved position, the values read since are
//     unread one by one and returned (passed through slices.Reverse so that
//     they come out in read order).
//   - If the reader is behind the saved position, values are read until the
//     saved position is reached and returned in read order.
//   - If no position was pushed, or the positions are equal, an empty
//     non-nil slice is returned.
func (r *SafeReader[T]) Pop() []T {
	if r.indices.Empty() {
		return []T{}
	}

	mark := r.indices.Pop()
	pos := r.index

	switch {
	case mark < pos:
		count := int(pos - mark)
		rewound := make([]T, 0, count)
		for ; count > 0; count-- {
			// Step back one value, then look at it without consuming it.
			r.Unread()
			rewound = append(rewound, r.Peek())
		}
		// Values were collected newest-first.
		return slices.Reverse(rewound)

	case mark > pos:
		count := int(mark - pos)
		replayed := make([]T, 0, count)
		for ; count > 0; count-- {
			replayed = append(replayed, r.Read())
		}
		return replayed

	default:
		return []T{}
	}
}