import {
  useQuery,
  UseQueryOptions,
  QueryKey,
  UseQueryResult,
} from '@tanstack/react-query';
import { useState } from 'react';
import completion, { Prompt, RequestPayload } from '../client/backendClient';

/**
 * Options accepted by `useLazyQuery`: the regular `useQuery` options, except
 * that `initialData` may only be omitted or be a function returning
 * `undefined`.
 *
 * NOTE(review): presumably this restriction keeps `useQuery` on the overload
 * whose `data` may be `undefined`, matching the `UseQueryResult` return type of
 * `useLazyQuery`. Confirm against the installed @tanstack/react-query version.
 */
type UseLazyQueryOptions<
  TQueryFnData,
  TError,
  TData,
  TQueryKey extends QueryKey
> = Omit<
  UseQueryOptions<TQueryFnData, TError, TData, TQueryKey>,
  'initialData'
> & { initialData?: () => undefined };

/**
 * Wraps `useQuery` so the query only fires once `trigger` has been called.
 *
 * Until `trigger()` is invoked the query is disabled. After that it stays
 * enabled for the lifetime of the component; calling `trigger` again is a
 * no-op. Later fetches follow `useQuery`'s normal rules, e.g. a query-key
 * change.
 *
 * A caller-supplied `enabled` option is still honoured. The query runs only
 * when it has been triggered AND `options.enabled` is not `false`.
 *
 * @param options - regular `useQuery` options (see `UseLazyQueryOptions`).
 * @returns the `useQuery` result plus a `trigger` function that enables it.
 */
export const useLazyQuery = <
  TQueryFnData = unknown,
  TError = unknown,
  TData = TQueryFnData,
  TQueryKey extends QueryKey = QueryKey
>(
  options: UseLazyQueryOptions<TQueryFnData, TError, TData, TQueryKey>
): UseQueryResult<TData, TError> & { trigger: () => void } => {
  const [triggered, setTriggered] = useState(false);
  // Combine with the caller's `enabled` instead of overwriting it, so both
  // gates must be open for the query to run.
  const query = useQuery({
    ...options,
    enabled: triggered && (options.enabled ?? true),
  });
  return {
    trigger: () => setTriggered(true),
    ...query,
  };
};

/**
 * Lazily fetches a completion for `payload`. Nothing is requested until the
 * returned `trigger` is called.
 *
 * The query key is built from:
 * - the last chat entry (or placeholder values when the chat is empty),
 * - the step index,
 * - the request type,
 * - the session id.
 *
 * Once triggered, a change to any of these issues a new request.
 *
 * @param payload - request body passed to `completion`.
 * @returns `trigger`, `isInitialLoading`, `error` and `data` from the query.
 */
export const useGetCompletionLazy = (payload: RequestPayload) => {
  const chat = payload.context?.chat;
  // Placeholder key parts used while the chat is still empty.
  const lastEntry = (chat && chat[chat.length - 1]) ?? {
    prompt: { value: 'initial-prompt' },
    userInput: 'initial-input',
  };
  const { error, data, isInitialLoading, trigger } = useLazyQuery({
    // `type` and `sessionId` are part of the request, so they must be part of
    // the key. Otherwise two sessions with the same last entry and step would
    // share one entry and the second would never be fetched. They are appended
    // last so the existing key prefix is unchanged.
    queryKey: [
      'completion',
      lastEntry.prompt.value,
      lastEntry.userInput,
      payload.context?.stepIndex,
      payload.type,
      payload.sessionId,
    ],
    queryFn: () => completion(payload),
    cacheTime: 0,
    refetchOnWindowFocus: false,
  });
  return {
    trigger,
    isInitialLoading,
    error,
    data,
  };
};

/**
 * Manages a step-by-step completion conversation for one session.
 *
 * Holds the accumulated request payload in state. Each call to the returned
 * `trigger(text)` does three things:
 * - appends `{ prompt, userInput: text }` to the chat,
 * - advances `stepIndex` by one (the first call yields step 0),
 * - enables the lazy completion query so the updated payload is sent.
 *
 * @param type - request type copied into every payload.
 * @param sessionId - session id copied into every payload.
 * @param prompt - prompt for the current step; must be defined before
 *   `trigger` is called.
 * @returns `trigger`, plus `data`, `error` and `isInitialLoading` of the
 *   completion query.
 * @throws Error from `trigger` when `prompt` is undefined.
 */
export const useCompletion = (
  type: string,
  sessionId: string,
  prompt?: Prompt
) => {
  // Placeholder payload: the query stays disabled until `trigger` is called,
  // and by then this has been replaced. stepIndex -1 makes the first triggered
  // step 0.
  const [payload, setPayload] = useState<RequestPayload>({
    context: {
      stepIndex: -1,
      chat: [],
    },
    type,
    sessionId,
  });

  const { trigger, data, error, isInitialLoading } =
    useGetCompletionLazy(payload);

  return {
    trigger: (text: string | string[]) => {
      if (!prompt) {
        throw new Error(
          "We shouldn't have progressed to the next step without a prompt to submit!"
        );
      }

      setPayload((p) => ({
        context: {
          // `??` rather than `||`: only a missing index should fall back.
          stepIndex: (p.context?.stepIndex ?? 0) + 1,
          chat: [
            ...(p.context?.chat ?? []),
            {
              prompt,
              userInput: text,
            },
          ],
        },
        type,
        sessionId,
      }));

      // Enable the query. The payload update above is applied in the same
      // render or an earlier one, so the placeholder payload is never fetched.
      trigger();
    },
    data,
    error,
    isInitialLoading,
  };
};
