import { CdkVirtualScrollViewport, VirtualScrollStrategy } from "@angular/cdk/scrolling";
import { Subject, distinctUntilChanged } from "rxjs";
import { ListItem } from "./multiselect-list.component";

/**
 * Extra rows rendered outside the visible viewport: `paddingAbove` rows before
 * the first visible row and `paddingBelow` rows after the last visible one.
 */
const paddingAbove = 5,
    paddingBelow = 5;

/**
 * Virtual scroll strategy for a list whose rows may have different heights.
 *
 * Every row is assumed to be `averageRowHeight` tall until it has been rendered;
 * after each range update the rendered `<tr>` rows are measured and their real
 * heights are cached by item id, so offsets and the total content size become
 * progressively more accurate.
 */
export class MultiselectVirtualScrollStrategy implements VirtualScrollStrategy {
    /** Items used for all height/offset computations; replaced via `updateItems`. */
    items: ListItem[] = [];
    /** Last total content size computed from `items` (measured or estimated heights). */
    private totalHeight: number = 0;
    private viewport: CdkVirtualScrollViewport | null = null;
    /** Node whose children are scanned for rendered `<tr>` rows to measure. */
    private wrapper: ChildNode | null = null;

    // Measured row heights (getBoundingClientRect().height of the rendered row),
    // keyed by the item id in string form, i.e. the form read from `data-item-id`.
    private heightCache = new Map<string, number>();

    /**
     * @param idField Property of `ListItem.data` holding the item id. Its string
     *   form must equal the `data-item-id` attribute of the rendered row.
     * @param averageRowHeight Height assumed for rows that have not been measured yet.
     */
    constructor(private idField: string, private averageRowHeight: number) {}

    private scrolledIndexChange$ = new Subject<number>();
    /** Emits the index of the first visible item; consecutive duplicates are suppressed. */
    scrolledIndexChange = this.scrolledIndexChange$.pipe(distinctUntilChanged());

    onContentScrolled(): void {
        this.updateRenderedRange();
    }

    onDataLengthChanged(): void {
        this.viewport?.setTotalContentSize(this.getTotalHeight());
        this.updateRenderedRange();
    }

    onContentRendered(): void {
        // Not needed: rows are measured from updateRenderedRange.
    }

    onRenderedOffsetChanged(): void {
        // Not needed: the rendered offset is set explicitly in updateRenderedRange.
    }

    /**
     * Replaces the items used for height and offset computations.
     *
     * NOTE(review): this does not refresh the total content size by itself; that
     * happens in `onDataLengthChanged` and after newly rendered rows are measured.
     * Confirm callers trigger a refresh when the list changes without changing length.
     */
    updateItems(newItems: ListItem[]): void {
        this.items = newItems;
    }

    /** Scrolls so that the top edge of the item at `index` is at the scroll offset. */
    scrollToIndex(index: number, behavior: ScrollBehavior): void {
        if (!this.viewport) return;

        const offset = this.getOffsetByItemIndex(index);
        this.viewport.scrollToOffset(offset, behavior);
    }

    attach(viewport: CdkVirtualScrollViewport): void {
        this.viewport = viewport;
        // Rows are looked up among the children of the viewport element's first child
        // node (presumably the CDK content wrapper — TODO confirm).
        this.wrapper = viewport.getElementRef().nativeElement.childNodes[0] ?? null;
    }

    detach(): void {
        this.viewport = null;
        // Drop the DOM reference too, so a detached strategy never measures stale nodes.
        this.wrapper = null;
    }

    /** Cache key for an item: its id in the same string form as the `data-item-id` attribute. */
    private getItemId(item: ListItem): string {
        // A real conversion, not a cast: numeric ids must match the attribute string.
        return String(item.data[this.idField]);
    }

    /** Measured height of the item's row, or `averageRowHeight` if not measured yet. */
    private getItemHeight(item: ListItem): number {
        return this.heightCache.get(this.getItemId(item)) ?? this.averageRowHeight;
    }

    /** Sum of the (measured or estimated) heights of `items`. */
    private measureItemsHeight(items: ListItem[]): number {
        return items.reduce((sum, item) => sum + this.getItemHeight(item), 0);
    }

    /** Recomputes `totalHeight` from all current items and returns it. */
    private getTotalHeight(): number {
        this.totalHeight = this.measureItemsHeight(this.items);
        return this.totalHeight;
    }

    /** Offset of the top edge of the item at `index` (negative indexes clamp to 0). */
    private getOffsetByItemIndex(index: number): number {
        return this.measureItemsHeight(this.items.slice(0, Math.max(0, index)));
    }

    /** Index of the first item whose row is (at least partly) below `offset`. */
    private getItemIndexByOffset(offset: number): number {
        let offsetSum = 0;

        for (let i = 0; i < this.items.length; i++) {
            offsetSum += this.getItemHeight(this.items[i]);

            // offsetSum is now the bottom edge of item i; a bottom edge exactly at the
            // offset means the item is fully scrolled out, hence the strict comparison.
            if (offsetSum > offset) return i;
        }

        // Offset at or past the end of the content: clamp to the last item instead of
        // jumping back to the top of the list.
        return Math.max(0, this.items.length - 1);
    }

    /** Number of items, starting at `startIndex`, needed to fill the viewport height. */
    private determineItemsCountInViewport(startIndex: number): number {
        if (!this.viewport) return 0;

        let totalSize = 0;
        const viewportSize = this.viewport.getViewportSize();

        for (let i = startIndex; i < this.items.length; i++) {
            totalSize += this.getItemHeight(this.items[i]);

            if (totalSize >= viewportSize) return i - startIndex + 1;
        }

        return this.items.length - startIndex;
    }

    /**
     * Renders the visible items plus padding rows, positions them at the offset of
     * the first rendered item, emits the first visible index and measures new rows.
     */
    private updateRenderedRange(): void {
        if (!this.viewport) return;

        const scrollOffset = this.viewport.measureScrollOffset();
        const scrollIndex = this.getItemIndexByOffset(scrollOffset);
        const dataLength = this.viewport.getDataLength();

        const start = Math.max(0, scrollIndex - paddingAbove);
        const end = Math.min(
            dataLength,
            scrollIndex + this.determineItemsCountInViewport(scrollIndex) + paddingBelow
        );

        this.viewport.setRenderedRange({ start, end });
        this.viewport.setRenderedContentOffset(
            this.getOffsetByItemIndex(start)
        );
        this.scrolledIndexChange$.next(scrollIndex);

        this.updateHeightCache();
    }

    /**
     * Measures rendered rows that are not cached yet. If any were measured, the total
     * content size is recomputed and pushed to the viewport when it changed.
     */
    private updateHeightCache(): void {
        if (!this.wrapper || !this.viewport) return;

        let measuredNewRow = false;

        const nodes = this.wrapper.childNodes;
        for (let i = 0; i < nodes.length; i++) {
            const node = nodes[i] as HTMLElement;
            if (!node || node.nodeName !== "TR") continue;

            const id = node.getAttribute("data-item-id");
            // A row without an id cannot be matched to an item: nothing to cache.
            if (id === null || this.heightCache.has(id)) continue;

            // getBoundingClientRect().height keeps fractional pixels and includes
            // borders, unlike the integer-rounded clientHeight.
            this.heightCache.set(id, node.getBoundingClientRect().height);
            measuredNewRow = true;
        }

        if (!measuredNewRow) return;

        // Recompute instead of adjusting by (actual - average): a measured row may not
        // belong to the current `items` (e.g. right after `updateItems`), and an
        // incremental delta would then drift away from the real total.
        const previousTotal = this.totalHeight;
        if (this.getTotalHeight() !== previousTotal) {
            this.viewport.setTotalContentSize(this.totalHeight);
        }
    }
}
