import React, { useMemo, useEffect, useCallback } from "react";

/**
 * Props accepted by the `TrapFocus` component.
 */
interface TrapFocusProps {
  /**
   * The content to keep focus inside. It is passed to `React.cloneElement`
   * to receive a `ref` and `tabIndex={-1}`, so it has to be a single element
   * that can take both.
   */
  children: React.ReactNode;
  /**
   * Whether the trapped content is currently shown. When this is true on
   * mount, or changes to true, the element focused at that moment is
   * remembered as the place to return focus to.
   */
  open: boolean;
  /**
   * While false, the document-level focus listener ignores focus changes
   * and focus is not handed back to the remembered element on cleanup.
   */
  isActive: boolean;
}

/**
 * Keeps keyboard focus inside its single child element while the trap is
 * open and active, and returns focus to the element that was focused before
 * the trap opened once the trap closes or unmounts.
 *
 * The child is cloned to receive the root `ref` and `tabIndex={-1}` (so the
 * root itself can be focused programmatically), and is rendered between two
 * focusable sentinel `div`s.
 *
 * @param open - Whether the trapped content is shown. Nothing is focused or
 *   listened to while this is false.
 * @param isActive - While false, focus changes are not redirected and focus
 *   is not restored on cleanup.
 * @param children - A single element that accepts `ref` and `tabIndex`.
 */
export const TrapFocus = ({ open, isActive, children }: TrapFocusProps) => {
  // True only while this component itself is calling `.focus()`, so the
  // document-level listener can ignore the focus events it caused.
  const ignoreUtilFocusChanges = React.useRef(false);
  const sentinelStart = React.useRef<HTMLDivElement | null>(null);
  const sentinelEnd = React.useRef<HTMLDivElement | null>(null);
  // Element that had focus when the trap opened; focus goes back to it later.
  const returnFocusTo = React.useRef<Element | null>(null);
  const rootRef = React.useRef<HTMLDivElement | null>(null);
  // Most recent focus target seen inside the root (or where focus was put
  // after the last redirect).
  const lastFocus = React.useRef<Element | null | undefined>(undefined);

  /**
   * Tries to move focus to `node`.
   *
   * @returns true only if `node` ended up as `document.activeElement`.
   */
  const attemptFocus = useCallback((node?: Node | null): boolean => {
    // Child nodes may be text/comment nodes, which have no `focus` method.
    const candidate = node as HTMLElement | null | undefined;
    if (!candidate || typeof candidate.focus !== "function") {
      return false;
    }

    ignoreUtilFocusChanges.current = true;
    try {
      candidate.focus();
    } catch (e) {
      // Best effort: a node that cannot take focus is simply reported as
      // "not focused" by the activeElement comparison below.
    } finally {
      ignoreUtilFocusChanges.current = false;
    }
    return document.activeElement === candidate;
  }, []);

  /**
   * Depth-first search, in document order, for the first descendant of
   * `node` that accepts focus.
   *
   * @returns true if some descendant was focused, false otherwise.
   */
  const focusFirstDescendant = useCallback(
    (node?: Node | null): boolean => {
      if (!node) return false;

      for (let i = 0; i < node.childNodes.length; i++) {
        const child = node.childNodes[i];
        if (attemptFocus(child) || focusFirstDescendant(child)) {
          return true;
        }
      }
      return false;
    },
    [attemptFocus]
  );

  /**
   * Same as `focusFirstDescendant`, but walks the children in reverse so the
   * last focusable descendant is chosen.
   *
   * @returns true if some descendant was focused, false otherwise.
   */
  const focusLastDescendant = useCallback(
    (node?: Node | null): boolean => {
      if (!node) return false;

      for (let i = node.childNodes.length - 1; i >= 0; i--) {
        const child = node.childNodes[i];
        if (attemptFocus(child) || focusLastDescendant(child)) {
          return true;
        }
      }
      return false;
    },
    [attemptFocus]
  );

  // Runs during render rather than in an effect, so the active element is
  // read before the effect below moves focus to the root.
  // NOTE(review): React documents useMemo as an optimisation, not a semantic
  // guarantee; if the cache is ever discarded this would re-capture whatever
  // is focused at that later render.
  useMemo(() => {
    // `document` does not exist when rendering on the server.
    if (!open || typeof document === "undefined") {
      return;
    }

    returnFocusTo.current = document.activeElement;
  }, [open]);

  useEffect(() => {
    // A closed trap must neither grab focus nor listen for focus changes.
    // Without this guard, closing while still mounted would restore focus in
    // the cleanup and then immediately pull it back to the root.
    if (!open) {
      return undefined;
    }

    // Move focus into the trap unless something inside already has it.
    if (rootRef.current && !rootRef.current.contains(document.activeElement)) {
      rootRef.current.focus();
    }

    const trapFocus = (e: FocusEvent) => {
      if (ignoreUtilFocusChanges.current || !rootRef.current || !isActive) {
        return;
      }

      const target = e.target as Node | null;
      if (rootRef.current.contains(target)) {
        lastFocus.current = target as Element | null;
      } else {
        // Focus landed outside the root: pull it back to the first focusable
        // descendant. If that is the element that already had focus, wrap
        // around to the last focusable descendant instead (presumably the
        // user was moving backwards out of the trap).
        focusFirstDescendant(rootRef.current);
        if (lastFocus.current === document.activeElement) {
          focusLastDescendant(rootRef.current);
        }
        lastFocus.current = document.activeElement;
      }
    };

    // "focus" does not bubble, so listen in the capture phase.
    document.addEventListener("focus", trapFocus, true);

    return () => {
      document.removeEventListener("focus", trapFocus, true);

      const returnFocusNode = returnFocusTo.current as HTMLElement | null;

      if (
        isActive &&
        returnFocusNode &&
        typeof returnFocusNode.focus === "function"
      ) {
        returnFocusNode.focus();
      }
    };
  }, [focusFirstDescendant, focusLastDescendant, isActive, open]);

  // The sentinels are focusable elements placed outside the root on either
  // side of it; a focus event on them is handled by the listener's
  // "outside the root" branch above.
  return (
    <>
      <div tabIndex={0} ref={sentinelStart} />
      {React.cloneElement(children as React.ReactElement<any>, {
        ref: rootRef,
        tabIndex: -1,
      })}
      <div tabIndex={0} ref={sentinelEnd} />
    </>
  );
};
