import { atlasConnectQuery } from "@augmedi/proto-gen";
import { assertDefined, isTruthy, unreachable } from "@augmedi/type-utils";
import { useQuery } from "@connectrpc/connect-query";
import {
  Alert,
  Button,
  Divider,
  Grid,
  GridCol,
  Input,
  InputWrapper,
  List,
  Loader,
  Select,
  Stack,
  Text,
  Title,
} from "@mantine/core";
import { useSuspenseQuery as tsUseSuspenseQuery } from "@tanstack/react-query";
import assert from "assert";
import { isEqual, sortBy } from "lodash-es";
import { Suspense, useEffect, useMemo, useRef, useState } from "react";
import { useParams, useSearch } from "wouter";
import { navigate } from "wouter/use-browser-location";
import z from "zod";
import { annotatorLeftModalProps } from "../../logic/annotator-left-modal";
import { AppLayout, useAppLayout } from "../../logic/app-layout";
import { useNavigationLockAndBeforeUnload } from "../../logic/navigation-lock";
import { showErrorNotification } from "../../logic/notification";
import {
  getTreeAnatomyUrl,
  parseTreeAnatomyJson,
  type ComponentLabel,
  type LabelSet,
} from "../../logic/tree-anatomy-labelling";
import ModelViewerUi from "../ModelViewerUi";
import { OurCanvas } from "../OurCanvas";
import {
  SharedModelPreviewStuff,
  type SharedModelPreviewStuffRef,
} from "../SharedModelPreviewStuff";
import { StructureSelect } from "../StructureSelect";
import { TreeAnatomyCanvasContent } from "./TreeAnatomyCanvasContent";

/** Schema for a single entry of a project's `components.json`. */
const componentZod = z.object({
  // Position of this component in the `components` array.
  i_component: z.number(),
  radius: z.number(),
  distance_along_loops: z.number(),
  adjacent_components: z.number().array(),
  // Parent component indices; the first entry is treated as the primary parent.
  estimated_parents: z.number().array(),
});

/** One mesh component of the tree anatomy, as parsed from `components.json`. */
export type Component = z.infer<typeof componentZod>;

/** Query-string parameters understood by the labelling page. */
const searchParamsZod = z.object({
  // URL of a label JSON file whose labels are loaded when the page opens.
  initialBlobUrl: z.string().optional(),
});

/** The part of a persisted component label that is edited in memory. */
type InMemoryComponentLabel = Pick<ComponentLabel, "directStructureId">;

/**
 * Strategy used to color meshes in the 3D view.
 * The legend for each mode lives in `descriptionsByColorMode`.
 */
enum ColorMode {
  SameAsSelection = "SameAsSelection",
  DirectParents = "DirectParents",
  DirectChildren = "DirectChildren",
  TransitiveParents = "TransitiveParents",
  TransitiveChildren = "TransitiveChildren",
  RandomMesh = "RandomMesh",
  RandomStructureDirect = "RandomStructureDirect",
  RandomStructureInherited = "RandomStructureInherited",
  HasStructure = "HasStructure",
}

/** All color modes in declaration order; used as the selector's options. */
const allColorModes = Object.values(ColorMode);

/** Legend lines displayed under the color-mode selector, keyed by mode. */
const descriptionsByColorMode: Record<ColorMode, string[]> = {
  [ColorMode.SameAsSelection]: [
    "Blue: Selected mesh",
    "Green: Meshes with the same direct and inherited structure as the inherited structure of the selected mesh",
    "Yellow: Meshes with the same inherited structure as the inherited structure of the selected mesh",
  ],
  [ColorMode.DirectParents]: [
    "Blue: Selected mesh",
    "Green: Primary parent of the selected mesh",
    "Yellow: Other direct parents of the selected mesh",
  ],
  [ColorMode.DirectChildren]: [
    "Blue: Selected mesh",
    "Green: Direct children of the selected mesh",
  ],
  [ColorMode.TransitiveParents]: [
    "Blue: Selected mesh",
    "Green: Parents of the selected mesh, including over multiple steps",
  ],
  [ColorMode.TransitiveChildren]: [
    "Blue: Selected mesh",
    "Green: Children of the selected mesh, including over multiple steps",
  ],
  [ColorMode.RandomMesh]: [
    "Black: Selected mesh",
    "Random colors: Each mesh has a unique color",
  ],
  [ColorMode.RandomStructureDirect]: [
    "Black: Selected mesh",
    "Random colors: Each structure has a unique color (meshes with the same direct structure have the same color)",
  ],
  [ColorMode.RandomStructureInherited]: [
    "Black: Selected mesh",
    "Random colors: Each structure has a unique color (meshes with the same inherited structure have the same color)",
  ],
  [ColorMode.HasStructure]: [
    "Blue: Selected mesh",
    "Green: Meshes with a structure directly assigned",
    "Yellow: Meshes with a structure inherited from a transitive parent",
  ],
};

/**
 * Resolves the inherited structure id of every component by following its
 * chain of *primary* parents (the first entry of `parentsByComponent[i]`)
 * until a component with a directly assigned (truthy) structure id is found.
 *
 * - A component with a direct structure id keeps it.
 * - If a chain ends at a parentless component without a direct id, or loops
 *   back on itself, every component on that chain receives the (falsy) value
 *   stored for the component where the walk stopped.
 *
 * The walk is iterative, so very long parent chains cannot overflow the call
 * stack.
 *
 * @param directStructureIds Directly assigned structure id per component index.
 * @param parentsByComponent Parent indices per component; entry 0 is the primary parent.
 * @returns Inherited structure id per component index (same length as
 *   `directStructureIds`). The inputs are not mutated.
 */
function calculateInheritedStructureIds(
  directStructureIds: (string | undefined)[],
  parentsByComponent: number[][],
): (string | undefined)[] {
  const inheritedStructureIds = [...directStructureIds];
  // A component is "resolved" once its inherited id is final, or is about to be
  // finalized by the walk in progress. Directly labelled components start resolved.
  const resolved = directStructureIds.map((id) => !!id);

  for (let start = 0; start < inheritedStructureIds.length; start++) {
    // Components visited on this walk whose inherited id is still to be assigned.
    const pending: number[] = [];
    let current = start;
    while (!resolved[current]) {
      resolved[current] = true;
      const parents = parentsByComponent[current];
      if (parents.length === 0) {
        // Parentless and unlabelled: it keeps its own value, and so does
        // everything below it on this walk.
        break;
      }
      pending.push(current);
      current = parents[0];
    }
    // `current` is now one of: a directly labelled component, a component
    // finished by an earlier walk, the point where this walk looped back on
    // itself, or an unlabelled root. In each case its value is what the
    // pending components inherit.
    const inherited = inheritedStructureIds[current];
    for (const index of pending) {
      inheritedStructureIds[index] = inherited;
    }
  }

  return inheritedStructureIds;
}

/**
 * Computes a display color for every component according to `colorMode`.
 *
 * Components that no rule applies to remain "gray". Modes defined relative to
 * the selection (SameAsSelection, DirectParents, DirectChildren,
 * TransitiveParents, TransitiveChildren) leave everything gray when nothing
 * is selected. The legend for each mode is in `descriptionsByColorMode`.
 *
 * @param components All components; `i_component` is used as the output index.
 * @param directStructureIdsByComponent Directly assigned structure id per component.
 * @param inheritedStructureIdsByComponent Inherited structure id per component.
 * @param selectedComponentIndex Index of the selected component, if any.
 * @param colorMode The coloring rule to apply.
 * @returns One CSS color string per component.
 */
function calculateComponentColors(
  components: Component[],
  directStructureIdsByComponent: (string | undefined)[],
  inheritedStructureIdsByComponent: (string | undefined)[],
  selectedComponentIndex: number | undefined,
  colorMode: ColorMode,
): string[] {
  const palette = Array.from(components, () => "gray");

  // Parent lists come straight from the data; child lists are their inverse.
  const parentLists = components.map(
    ({ estimated_parents }) => estimated_parents,
  );
  const childLists: number[][] = components.map(() => []);
  components.forEach(({ i_component, estimated_parents }) => {
    estimated_parents.forEach((parent) => {
      childLists[parent].push(i_component);
    });
  });

  /** Every component reachable from `origin` along `edges`, excluding `origin`. */
  const reachableFrom = (origin: number, edges: number[][]): Set<number> => {
    const seen = new Set<number>([origin]);
    const frontier = [origin];
    let node = frontier.pop();
    while (node !== undefined) {
      for (const neighbour of edges[node]) {
        if (!seen.has(neighbour)) {
          seen.add(neighbour);
          frontier.push(neighbour);
        }
      }
      node = frontier.pop();
    }
    seen.delete(origin);
    return seen;
  };

  /** Spreads hues by the golden angle so consecutive indices look distinct. */
  const hueFor = (n: number) => `hsl(${(n * 137.508) % 360}, 100%, 50%)`;

  /**
   * Gives the selected component `selectedColor`. Every other component gets
   * whatever `colorOf` returns, or keeps its current color if that is undefined.
   */
  const paint = (
    selectedColor: string,
    colorOf: (i: number) => string | undefined,
  ): void => {
    for (const { i_component: i } of components) {
      if (i === selectedComponentIndex) {
        palette[i] = selectedColor;
        continue;
      }
      const color = colorOf(i);
      if (color !== undefined) {
        palette[i] = color;
      }
    }
  };

  switch (colorMode) {
    case ColorMode.SameAsSelection: {
      if (selectedComponentIndex === undefined) {
        break;
      }
      const target = inheritedStructureIdsByComponent[selectedComponentIndex];
      paint("blue", (i) => {
        if (inheritedStructureIdsByComponent[i] !== target) {
          return undefined;
        }
        return directStructureIdsByComponent[i] === target ? "green" : "yellow";
      });
      break;
    }

    case ColorMode.DirectParents: {
      if (selectedComponentIndex === undefined) {
        break;
      }
      const parents =
        components[selectedComponentIndex].estimated_parents || [];
      const [primaryParent] = parents;
      paint("blue", (i) => {
        if (i === primaryParent) {
          return "green";
        }
        return parents.includes(i) ? "yellow" : undefined;
      });
      break;
    }

    case ColorMode.DirectChildren: {
      if (selectedComponentIndex === undefined) {
        break;
      }
      const children = childLists[selectedComponentIndex];
      paint("blue", (i) => (children.includes(i) ? "green" : undefined));
      break;
    }

    case ColorMode.TransitiveParents:
    case ColorMode.TransitiveChildren: {
      if (selectedComponentIndex === undefined) {
        break;
      }
      const relatives = reachableFrom(
        selectedComponentIndex,
        colorMode === ColorMode.TransitiveParents ? parentLists : childLists,
      );
      paint("blue", (i) => (relatives.has(i) ? "green" : undefined));
      break;
    }

    case ColorMode.RandomMesh:
      paint("black", hueFor);
      break;

    case ColorMode.RandomStructureDirect:
    case ColorMode.RandomStructureInherited: {
      const structureIds =
        colorMode === ColorMode.RandomStructureDirect
          ? directStructureIdsByComponent
          : inheritedStructureIdsByComponent;
      // Sorting the distinct ids keeps each structure's color stable.
      const distinctIds = sortBy([...new Set(structureIds.filter(isTruthy))]);
      const colorByStructure = new Map<string, string>();
      distinctIds.forEach((structureId, rank) => {
        colorByStructure.set(structureId, hueFor(rank));
      });
      paint("black", (i) => {
        const structureId = structureIds[i];
        return structureId
          ? assertDefined(colorByStructure.get(structureId))
          : undefined;
      });
      break;
    }

    case ColorMode.HasStructure:
      paint("blue", (i) => {
        if (directStructureIdsByComponent[i]) {
          return "green";
        }
        if (inheritedStructureIdsByComponent[i]) {
          return "yellow";
        }
        return undefined;
      });
      break;

    default:
      return unreachable(colorMode);
  }

  return palette;
}

/**
 * Triggers a browser download of `content`, serialized as pretty-printed JSON
 * and saved under `fileName`.
 *
 * The temporary object URL is revoked afterwards so the blob can be released.
 * Revocation is deferred rather than done immediately after `click()`, so it
 * cannot race with the browser reading the blob for the download.
 */
function downloadJsonFile(fileName: string, content: unknown): void {
  const documentContent = JSON.stringify(content, null, 2);
  const blob = new Blob([documentContent], { type: "application/json" });
  const url = URL.createObjectURL(blob);
  const a = document.createElement("a");
  a.href = url;
  a.download = fileName;
  a.click();
  setTimeout(() => URL.revokeObjectURL(url), 10_000);
}

/**
 * Page for labelling the mesh components of a tree-anatomy project.
 *
 * - Loads `<projectName>/components.json` through a suspense query.
 * - Optionally seeds labels from the `initialBlobUrl` search parameter.
 * - Lets the user select a mesh in the 3D view and assign it a structure.
 *   Structures are inherited along primary-parent chains.
 * - Colors meshes according to the chosen color mode.
 * - Saves the labels as a downloaded JSON file.
 */
export const TreeAnatomyLabellingPage = () => {
  const { projectName } = useParams<{ projectName: string }>();
  const searchString = useSearch();
  const { initialBlobUrl } = useMemo(
    () =>
      searchParamsZod.parse(
        Object.fromEntries(new URLSearchParams(searchString).entries()),
      ),
    [searchString],
  );

  useAppLayout(AppLayout.FullscreenWithHeader);

  const sharedStuffRef = useRef<SharedModelPreviewStuffRef>(null);

  const componentsJsonUrl = getTreeAnatomyUrl(`${projectName}/components.json`);
  const componentsQuery = tsUseSuspenseQuery({
    queryKey: ["tree-anatomy-components-json", componentsJsonUrl],
    queryFn: async () => {
      const res = await fetch(componentsJsonUrl);
      // Fail with a clear message on HTTP errors instead of a JSON parse error.
      if (!res.ok) {
        throw new Error(
          `Failed to fetch ${componentsJsonUrl}: ${res.status} ${res.statusText}`,
        );
      }
      const json = await res.json();
      const parsedRes = z
        .object({ components: z.array(componentZod) })
        .parse(json);
      // Components are addressed by array index below, so the file must list
      // them in `i_component` order.
      for (const [i, component] of parsedRes["components"].entries()) {
        assert(component["i_component"] === i);
      }
      return parsedRes["components"];
    },
  });
  const components = componentsQuery.data;

  const [labels, setLabels] = useState<Map<number, InMemoryComponentLabel>>(
    () => new Map(),
  );
  useEffect(() => {
    if (!initialBlobUrl) {
      return;
    }
    // Set by the cleanup so that a superseded load (URL/project changed) or a
    // load that finishes after unmount neither sets state nor navigates away.
    let cancelled = false;

    async function loadFromBlobUrl(blobUrl: string) {
      const res = await fetch(blobUrl);
      if (!res.ok) {
        throw new Error(
          `Failed to fetch ${blobUrl}: ${res.status} ${res.statusText}`,
        );
      }
      const json = await res.json();
      const parsed = parseTreeAnatomyJson(json);
      assert(parsed.projectName === projectName);

      const newLabels = new Map<number, InMemoryComponentLabel>();
      for (const [iComponent, label] of parsed.labels.entries()) {
        newLabels.set(iComponent, {
          directStructureId: label.directStructureId,
        });
      }
      if (!cancelled) {
        setLabels(newLabels);
      }
    }

    loadFromBlobUrl(initialBlobUrl).catch((err: unknown) => {
      if (cancelled) {
        return;
      }
      console.error(err);
      showErrorNotification({
        message: "Failed to load labels. See console for more info.",
      });
      navigate("/label-tree-anatomy");
    });

    return () => {
      cancelled = true;
    };
  }, [initialBlobUrl, projectName]);

  const [_selectedComponentIndex, setSelectedComponentIndex] =
    useState<number>();
  // Ignore a stale selection that points past the end of the component list.
  const selectedComponentIndex =
    _selectedComponentIndex !== undefined &&
    _selectedComponentIndex < components.length
      ? _selectedComponentIndex
      : undefined;

  const [colorMode, setColorMode] = useState(
    ColorMode.RandomStructureInherited,
  );

  /** Applies `cb` to the label of the selected component (no-op if none). */
  const updateSelectedLabel = (
    cb: (label: InMemoryComponentLabel) => InMemoryComponentLabel,
  ) => {
    if (selectedComponentIndex === undefined) {
      return;
    }
    setLabels((labels) => {
      const newLabels = new Map(labels);
      const oldLabel: InMemoryComponentLabel =
        labels.get(selectedComponentIndex) ?? {};
      newLabels.set(selectedComponentIndex, cb(oldLabel));
      return newLabels;
    });
  };

  const { directStructureIdsByComponent, inheritedStructureIdsByComponent } =
    useMemo(() => {
      const directStructureIdsByComponent = components.map(
        (c) => labels.get(c.i_component)?.directStructureId,
      );
      const parentsByComponent = components.map((c) => c.estimated_parents);
      const inheritedStructureIdsByComponent = calculateInheritedStructureIds(
        directStructureIdsByComponent,
        parentsByComponent,
      );
      return {
        directStructureIdsByComponent,
        inheritedStructureIdsByComponent,
      };
    }, [components, labels]);

  const componentColors = useMemo(
    () =>
      calculateComponentColors(
        components,
        directStructureIdsByComponent,
        inheritedStructureIdsByComponent,
        selectedComponentIndex,
        colorMode,
      ),
    [
      components,
      directStructureIdsByComponent,
      inheritedStructureIdsByComponent,
      selectedComponentIndex,
      colorMode,
    ],
  );

  // Memoized so the 3D content only receives a new array when colors,
  // components or the project actually change.
  const componentSettings = useMemo(
    () =>
      components.map((component, i) => ({
        index: i,
        url: getTreeAnatomyUrl(
          `${projectName}/component_${component.i_component}.obj`,
        ),
        color: componentColors[i],
      })),
    [components, componentColors, projectName],
  );

  const selectedInheritedStructureId =
    selectedComponentIndex !== undefined
      ? inheritedStructureIdsByComponent[selectedComponentIndex]
      : undefined;
  const selectedInheritedStructureMetadataQuery = useQuery(
    atlasConnectQuery.getStructureMetadata,
    { id: selectedInheritedStructureId },
    { enabled: !!selectedInheritedStructureId },
  );

  // The document that "Save to file" writes.
  const nextVersion = useMemo((): LabelSet => {
    assert(
      directStructureIdsByComponent.length ===
        inheritedStructureIdsByComponent.length,
    );
    return {
      projectName,
      labels: directStructureIdsByComponent.map((directStructureId, i) => ({
        directStructureId,
        indirectStructureId: inheritedStructureIdsByComponent[i],
      })),
    };
  }, [
    directStructureIdsByComponent,
    inheritedStructureIdsByComponent,
    projectName,
  ]);
  const [lastSavedVersion, setLastSavedVersion] = useState<LabelSet>();
  const canSave = useMemo(
    () => !isEqual(nextVersion, lastSavedVersion),
    [nextVersion, lastSavedVersion],
  );
  const save = () => {
    downloadJsonFile(
      `${projectName}-labels-${new Date().toISOString()}.json`,
      nextVersion,
    );
    setLastSavedVersion(nextVersion);
  };
  // Lock navigation / page unload while there are unsaved changes.
  useNavigationLockAndBeforeUnload(canSave);

  return (
    <Grid h="100%" styles={{ inner: { height: "100%" } }} gutter={0}>
      <GridCol span={3} h="100%" p="md" style={{ overflow: "auto" }}>
        <Stack>
          <Title order={5}>Project</Title>
          <Button onClick={save} disabled={!canSave}>
            {canSave ? "Save to file" : "All changes saved"}
          </Button>
          <Select
            label="Color mode"
            data={allColorModes}
            value={colorMode}
            onChange={(value) => {
              // Map the Select's string value back onto the enum without a cast.
              const mode = allColorModes.find((m) => m === value);
              if (mode) {
                setColorMode(mode);
              }
            }}
          />
          <Alert color="gray">
            <List>
              {descriptionsByColorMode[colorMode].map((description, i) => (
                <List.Item key={i}>{description}</List.Item>
              ))}
            </List>
          </Alert>
          {selectedComponentIndex !== undefined && (
            <>
              <Divider />
              <Title order={5}>Selected mesh</Title>
              <InputWrapper label="Directly assigned structure">
                <StructureSelect
                  structureId={
                    directStructureIdsByComponent[selectedComponentIndex]
                  }
                  onStructureIdChange={(structureId) =>
                    updateSelectedLabel((label) => ({
                      ...label,
                      directStructureId: structureId,
                    }))
                  }
                  spotlightRootProps={annotatorLeftModalProps}
                />
              </InputWrapper>
              <div>
                <Input.Label>Inherited structure</Input.Label>
                <Text>
                  {selectedInheritedStructureMetadataQuery.data?.base?.name}
                </Text>
              </div>
            </>
          )}
        </Stack>
      </GridCol>
      <GridCol
        span={9}
        style={{
          borderLeft: "1px solid #dee2e6",
          borderRight: "1px solid #dee2e6",
        }}
      >
        <Suspense fallback={<Loader />}>
          <ModelViewerUi
            onResetCamera={() => sharedStuffRef.current?.resetCamera()}
          >
            <OurCanvas>
              <SharedModelPreviewStuff
                ref={sharedStuffRef}
                backgroundColor={0xffffff}
              />
              <TreeAnatomyCanvasContent
                componentSettings={componentSettings}
                onSelectComponent={(settings) =>
                  setSelectedComponentIndex(settings.index)
                }
              />
            </OurCanvas>
          </ModelViewerUi>
        </Suspense>
      </GridCol>
    </Grid>
  );
};
