import { nanoid } from 'nanoid';
import ReconnectingWebSocket from 'reconnecting-websocket';

import { getAuthToken } from '#/api/auth';
import { UnixTimeMicroseconds, UnixTimeMs } from '#/api/types';

import {
  userSignOutTopic,
  userUnauthorizedTopic,
} from '#/features/auth/auth-topics';
import { FeatureFlags } from '#/features/feature-flags';
import { trackUserSignOut } from '#/features/logging/datadog/user-session-tracking';
import { logEvent, logException } from '#/features/logging/logging';

import createTopic, { EventHandler } from '#/utils/createTopic';
import { usToMs } from '#/utils/date';
import { AbortError, isTimeoutError, TimeoutError } from '#/utils/errors';
import { Maybe } from '#/utils/types';

/**
 * Facade returned by `createWebSocket`: JSON-RPC requests, a stream of
 * processed notifications, and pass-throughs to the underlying
 * `ReconnectingWebSocket`.
 *
 * @typeParam N - Notification type delivered to handlers (the output of
 *   `CreateWebSocketOptions.processNotification`).
 */
export interface ApiWebSocket<N = WSNotification> {
  /**
   * Sends a JSON-RPC request. Resolves with the response even when the server
   * answers with a JSON-RPC error (exposed as an `Error` on `error`); rejects
   * with a `TimeoutError` when no response arrives in time and with an
   * `AbortError` when `signal` aborts.
   */
  readonly request: (reqParams: WSReqParams) => Promise<WSResponse<any>>; // eslint-disable-line @typescript-eslint/no-explicit-any
  /** @returns Unsubscribe function */
  readonly addNotificationHandler: (handler: EventHandler<N>) => () => void;
  readonly removeNotificationHandler: (handler: EventHandler<N>) => void;
  /** Ready state of the underlying socket, read live on each access. */
  readonly readyState: ReconnectingWebSocket['readyState'];
  readonly close: ReconnectingWebSocket['close'];
  readonly reconnect: ReconnectingWebSocket['reconnect'];
  readonly addEventListener: ReconnectingWebSocket['addEventListener'];
  readonly removeEventListener: ReconnectingWebSocket['removeEventListener'];
}

/** JSON-RPC 2.0 request frame, serialized with `JSON.stringify` and sent over the socket. */
export interface WSRequest {
  readonly jsonrpc: '2.0';
  /** Echoed back in the matching response; used to route it to the pending request. */
  readonly id: string | number;
  readonly method: string;
  readonly params?: { readonly [key: string]: unknown };
}

/**
 * Server-initiated JSON-RPC message. Incoming frames without an `id` field are
 * treated as notifications rather than responses. `params.channel` names the
 * channel (presumably the subscription) that `params.data` belongs to.
 */
export interface WSNotification<Data = unknown> {
  readonly jsonrpc: '2.0';
  readonly method: string;
  readonly params: {
    readonly channel: string;
    readonly data: Data;
  };
}

/** JSON-RPC 2.0 response frame exactly as received, before error normalization. */
interface RawWSResponse<T = unknown> {
  readonly jsonrpc: '2.0';
  /** Matches the `id` of the originating `WSRequest`. */
  readonly id: string | number;
  readonly result?: T;
  readonly error?: WSError;
  /** Server-side processing duration in microseconds (logged in ms after dividing by 1000). */
  readonly usDiff?: number;
  /** Server-side timestamp, Unix time in microseconds (presumably when the request arrived). */
  readonly usIn?: number;
  /** Server-side timestamp, Unix time in microseconds (presumably when the response left). */
  readonly usOut?: number;
}

/**
 * Response handed to callers of `request`: the raw frame with the server's
 * `WSError` replaced by an `Error` whose message names the method, the error
 * code, and the server's message.
 */
export interface WSResponse<T = unknown>
  extends Omit<RawWSResponse<T>, 'error'> {
  readonly error?: Error;
}

/**
 * JSON-RPC error object sent by the server. `createWebSocket` treats
 * `code === 40111` as "unauthorized" and publishes `userUnauthorizedTopic`.
 */
export interface WSError {
  readonly code?: number;
  readonly message?: string;
  readonly data?: {
    readonly method?: string;
    readonly timestamp?: number;
  };
}

/** Promise of a `WSResponse` whose `result` is typed as `T`. */
export type AsyncWSResp<T> = Promise<WSResponse<T>>;

/** Arguments to `ApiWebSocket.request`. */
export interface WSReqParams {
  /** JSON-RPC method name. */
  readonly method: string;
  /** JSON-RPC params; `params.channel`, when present, is also recorded in log metadata. */
  readonly params?: Record<string, unknown>;
  /** Aborting rejects the pending request with an `AbortError`. */
  readonly signal?: AbortSignal;
  /** Milliseconds, defaults to 10s. Use <= 0 to disable. */
  readonly timeout?: number;
}

/** Stored per pending request; invoked with the raw frame whose `id` matches. */
type ReqCallback = (resp: RawWSResponse<unknown>) => void;

/**
 * Logging metadata for an outgoing request. `enqueuedAt` is `Date.now()` when
 * the frame was built and `registeredAt` is `Date.now()` when its response
 * callback was registered; `'pending registration'` is the placeholder used
 * before registration happens.
 */
interface WsRequestMetadata {
  readonly logId: string;
  readonly id: string | number;
  readonly method: string;
  readonly channel: string;
  readonly enqueuedAt: number;
  readonly registeredAt: number | 'pending registration';
  readonly enqueuedToRegisteredDiffMs: number | 'pending registration';
}

/**
 * Logging metadata for a received response. `receivedAt` is `Date.now()` at
 * receipt. The string `'undefined'` marks values that could not be derived,
 * e.g. when no pending request matched the response id.
 */
interface WsResponseMetadata {
  readonly logId: string;
  readonly id: string | number;
  readonly method: string;
  readonly channel: string;
  readonly receivedAt: number;
  readonly sentToReceivedDiffMs: number | 'undefined';
}

/**
 * Logging metadata derived from the server's timing fields. The string
 * `'undefined'` marks fields the server did not send.
 */
type WsResponseServerSideMetadata = {
  /** Raw server `usIn` timestamp (microseconds). */
  usInMicroseconds: number | 'undefined';
  /** Raw server `usOut` timestamp (microseconds). */
  usOutMicroseconds: number | 'undefined';
  /** Measures time to process WS request on Cloud API */
  usDiffMs: number | 'undefined';
  /** Derived by `calcApiWaitTimeMs` from the client and server timestamps */
  apiWaitTimeMs: number | 'undefined';
};

/**
 * Logging metadata for a timed-out request. `timedOutAt` is `Date.now()` when
 * the timeout fired; `sentToTimeoutDiff` is the milliseconds elapsed since the
 * request's callback was registered.
 */
interface WsTimeoutMetadata {
  readonly logId: string;
  readonly id: string | number;
  readonly method: string;
  readonly channel: string;
  readonly timedOutAt: number;
  readonly sentToTimeoutDiff: number | 'undefined';
}

/**
 * Options for `createWebSocket`.
 *
 * @typeParam RawNotification - Notification frame shape as received from the server.
 * @typeParam Notification - Shape published to notification handlers.
 */
export interface CreateWebSocketOptions<
  RawNotification extends WSNotification,
  Notification,
> {
  /** Resolves the socket URL; passed straight to `ReconnectingWebSocket` as its URL provider. */
  readonly urlProvider: () => Promise<string>;
  /** Maps each raw notification to the value published to handlers; return `'skip'` to drop it. */
  readonly processNotification: (raw: RawNotification) => Notification | 'skip';
}

/**
 * Create an instance of our lightweight abstraction over a WebSocket.
 *
 * Key features are:
 * - Auto reconnection when the connection closes
 * - Auto reconnection when the connection fails to open
 * - Auto reconnection when a request times out (no response received)
 * - JSON-RPC requests
 * - Notification messages
 *
 * Requests are matched to responses by JSON-RPC `id`. Pending requests are
 * tracked per connection: a `close` event forgets them, and such orphaned
 * requests later reject with a `TimeoutError` without cycling the (newer)
 * connection. Re-authentication on every `open` is wired up by
 * `enhanceWithAuthOnReconnect`.
 *
 * @param options - URL provider and notification mapper.
 * @returns The `ApiWebSocket` facade over a `ReconnectingWebSocket`.
 */
export default function createWebSocket<
  RawNotification extends WSNotification = WSNotification,
  Notification = WSNotification,
>(
  options: CreateWebSocketOptions<RawNotification, Notification>,
): ApiWebSocket<Notification> {
  const { urlProvider, processNotification } = options;

  // Flags are read once per socket instance.
  const isLatencyLoggingEnabled = FeatureFlags.getBooleanValue(
    'logging-websocket-request-response-latencies-enabled',
    false,
  );
  const isTimeoutLatencyLoggingEnabled = FeatureFlags.getBooleanValue(
    'logging-websocket-timeout-latencies-enabled',
    false,
  );

  // NOTE(review): `"false" === 'true'` looks like a build-time substituted
  // mock-socket switch; as written the mocked branch is unreachable.
  const ws =
    "false" === 'true'
      ? new ReconnectingWebSocket("wss://ws.api.testnet.paradex.trade/v1", [], {
          WebSocket: window.MockedWebSocket,
        })
      : new ReconnectingWebSocket(urlProvider);

  const notificationTopic = createTopic<Notification>('ws-notification');
  const getId = createCounter();
  /**
   * Requests awaiting a response on the current connection, keyed by JSON-RPC
   * id. Entries are removed when a request settles and wiped on `close`.
   */
  const reqCallbacksMap = new Map<
    string | number,
    { callback: ReqCallback; metadata: WsRequestMetadata }
  >();

  /** Formats the identifying fields shared by every request/response log line. */
  function formatLogIdentifiers(
    metadata: Pick<WsRequestMetadata, 'logId' | 'id' | 'method' | 'channel'>,
  ): string {
    return `logId='${metadata.logId}' id='${metadata.id}' method='${metadata.method}' channel='${metadata.channel}'`;
  }

  async function request(reqParams: WSReqParams) {
    const { signal, timeout = 10_000, ...reqRest } = reqParams;
    return new Promise<WSResponse>((_resolve, _reject) => {
      const id = getId();

      const req: WSRequest = { id, jsonrpc: '2.0', ...reqRest };

      const enqueuedRequestMetadata: WsRequestMetadata = {
        logId: nanoid(),
        id: req.id,
        method: req.method,
        channel: String(req.params?.channel),
        enqueuedAt: Date.now(),
        registeredAt: 'pending registration',
        enqueuedToRegisteredDiffMs: 'pending registration',
      };

      let timeoutHandle: ReturnType<typeof setTimeout> | undefined = undefined;

      const handleAbort = () => {
        reject(new AbortError('Web socket request aborted'));
      };

      // Releases everything this request registered; safe to call repeatedly.
      const cleanUp = () => {
        reqCallbacksMap.delete(id);
        clearTimeout(timeoutHandle);
        signal?.removeEventListener('abort', handleAbort);
      };

      // Server-side JSON-RPC errors resolve (not reject) with `error` set.
      const resolve = (resp: RawWSResponse) => {
        cleanUp();

        if (resp.error == null) {
          _resolve({ ...resp, error: undefined });
          return;
        }

        const error = new Error(
          `WebSocket request \`${req.method}\` failed with ` +
            `code ${resp.error.code}: ${resp.error.message}`,
        );
        _resolve({ ...resp, error });
      };

      const reject = (error: Error) => {
        cleanUp();
        _reject(error);
      };

      // Check before sending so an already-aborted request never reaches the
      // server.
      if (signal?.aborted === true) {
        reject(
          new AbortError(
            'Web socket request aborted before registering timeout',
          ),
        );
        return;
      }

      ws.send(JSON.stringify(req));
      signal?.addEventListener('abort', handleAbort);

      const registeredAt = Date.now();
      const registeredRequestMetadata: WsRequestMetadata = {
        ...enqueuedRequestMetadata,
        registeredAt,
        enqueuedToRegisteredDiffMs:
          registeredAt - enqueuedRequestMetadata.enqueuedAt,
      };
      reqCallbacksMap.set(id, {
        callback: resolve,
        metadata: registeredRequestMetadata,
      });
      if (isLatencyLoggingEnabled) {
        logEvent(
          `Registering WebSocket request ${formatLogIdentifiers(registeredRequestMetadata)}`,
          { logId: registeredRequestMetadata.logId, registeredRequestMetadata },
        );
      }

      // A non-positive timeout disables it (see `WSReqParams.timeout`).
      // Passing it to setTimeout would instead fire, and force a reconnect,
      // right away.
      if (timeout <= 0) return;

      timeoutHandle = setTimeout(() => {
        const timedOutAt = Date.now();
        const message = `Timed out WebSocket request ${formatLogIdentifiers(registeredRequestMetadata)}`;
        const error = new TimeoutError(message);
        const requestTimeoutMetadata: WsTimeoutMetadata = {
          logId: registeredRequestMetadata.logId,
          id: registeredRequestMetadata.id,
          method: reqParams.method,
          channel: String(reqParams.params?.channel),
          timedOutAt,
          sentToTimeoutDiff: timedOutAt - registeredAt,
        };
        const logContext = {
          logId: requestTimeoutMetadata.logId,
          requestMetadata: registeredRequestMetadata,
          requestTimeoutMetadata,
        };
        if (isTimeoutLatencyLoggingEnabled) {
          logException(error, logContext);
        }

        // `close` wipes reqCallbacksMap. If our entry is gone, this request
        // belonged to an earlier connection and could never be answered, so
        // its timeout says nothing about the health of the current one.
        const belongsToCurrentConnection =
          reqCallbacksMap.get(id)?.callback === resolve;
        reject(error);

        // Force reconnection to prevent the connection from getting stuck
        if (belongsToCurrentConnection && ws.readyState === WebSocket.OPEN) {
          if (isTimeoutLatencyLoggingEnabled) {
            logEvent(
              `Forcing WebSocket reconnection due to timed out WS request ${formatLogIdentifiers(registeredRequestMetadata)}`,
              logContext,
            );
          }
          ws.reconnect();
        }
      }, timeout);
    });
  }

  function authHandler(parsedMsg: RawWSResponse): void {
    const isUnauthorized = parsedMsg.error?.code === 40111;

    if (isUnauthorized) {
      userUnauthorizedTopic.publish({});
    }
  }

  function responseHandler(parsedMsg: RawWSResponse): void {
    reqCallbacksMap.get(parsedMsg.id)?.callback(parsedMsg);
  }

  function subscriptionHandler(parsedMsg: RawNotification): void {
    const notification = processNotification(parsedMsg);
    if (notification === 'skip') return;
    notificationTopic.publish(notification);
  }

  /**
   * Logs client- and server-side latency for a response. Must run before
   * `responseHandler`, which removes the request's metadata from the map.
   */
  function logResponseLatency(parsedMsg: RawWSResponse): void {
    // Missing when the request already settled, was wiped on `close`, or the
    // id is unknown to this client.
    const requestMetadata = reqCallbacksMap.get(parsedMsg.id)?.metadata;
    const receivedAt = Date.now();
    const requestMetadataFallback = {
      status: 'request metadata not found',
    } as const;
    const responseMetadata: WsResponseMetadata = {
      logId: String(requestMetadata?.logId),
      method: String(requestMetadata?.method),
      id: parsedMsg.id,
      channel: String(
        (parsedMsg.result as Maybe<{ channel: string }>)?.channel,
      ),
      receivedAt,
      sentToReceivedDiffMs:
        typeof requestMetadata?.registeredAt === 'number'
          ? receivedAt - requestMetadata.registeredAt
          : 'undefined',
    };
    const serverSideMetadata: WsResponseServerSideMetadata = {
      usDiffMs:
        typeof parsedMsg.usDiff === 'number'
          ? parsedMsg.usDiff / 1000
          : 'undefined',
      usInMicroseconds: parsedMsg.usIn ?? 'undefined',
      usOutMicroseconds: parsedMsg.usOut ?? 'undefined',
      apiWaitTimeMs: calcApiWaitTimeMs(
        requestMetadata?.enqueuedAt,
        parsedMsg.usIn,
        receivedAt,
        parsedMsg.usOut,
      ),
    };

    logEvent(
      `Received WebSocket response ${formatLogIdentifiers(responseMetadata)}`,
      {
        logId: requestMetadata?.logId,
        requestMetadata: requestMetadata ?? requestMetadataFallback,
        responseMetadata,
        serverSideMetadata,
        parsedMsg,
      },
    );
  }

  function wsMessageListener(message: MessageEvent<string>): void {
    // Malformed frames are logged and dropped instead of throwing out of the
    // socket's event dispatch.
    let raw: unknown;
    try {
      raw = JSON.parse(message.data);
    } catch (err) {
      logException(
        new Error('Failed to parse WebSocket message', { cause: err }),
      );
      return;
    }
    if (typeof raw !== 'object' || raw === null) {
      logException(new Error('Unexpected non-object WebSocket message'));
      return;
    }

    const parsedMsg = raw as RawWSResponse | RawNotification;
    const isResponse = 'id' in parsedMsg;

    if (isResponse) {
      // message due to a request (response)
      if (isLatencyLoggingEnabled) {
        logResponseLatency(parsedMsg);
      }
      authHandler(parsedMsg);
      responseHandler(parsedMsg);
    } else {
      // message due to a subscription (notification)
      subscriptionHandler(parsedMsg);
    }
  }

  // Responses to these ids cannot arrive on the next connection; the pending
  // promises are left to settle via their own timeouts.
  function clearReqCallbacks() {
    reqCallbacksMap.clear();
  }

  ws.addEventListener('message', wsMessageListener);
  ws.addEventListener('close', clearReqCallbacks);

  const apiWs: ApiWebSocket<Notification> = {
    request,
    addNotificationHandler: notificationTopic.subscribe,
    removeNotificationHandler: notificationTopic.unsubscribe,
    get readyState() {
      return ws.readyState;
    },
    close: ws.close.bind(ws),
    reconnect: ws.reconnect.bind(ws),
    addEventListener: ws.addEventListener.bind(ws),
    removeEventListener: ws.removeEventListener.bind(ws),
  } as const;

  enhanceWithAuthOnReconnect(apiWs);

  return apiWs;
}
// NOTE(review): this flag is read once, when the module is first evaluated,
// whereas `createWebSocket` reads its flags per socket instance — confirm the
// flag value is already available at import time.
const isL2OnlySessionEnabled = FeatureFlags.getBooleanValue(
  'l2-only-session-enabled',
  false,
);

/**
 * Re-sends the `auth` request with the current bearer token every time the
 * socket fires `open`.
 *
 * - Does nothing when there is no auth token.
 * - An error response signs the user out (or, with `l2-only-session-enabled`,
 *   publishes `userUnauthorizedTopic`).
 * - Timeouts are retried for up to 5 attempts with timeouts of 2s, 4s, …, 10s;
 *   any other failure, or the final timeout, is reported via `logException`.
 * - A newer `open` supersedes an in-flight attempt loop: the older loop stops
 *   without retrying or signing out. Timed-out requests can themselves force a
 *   reconnect, so without this, loops would pile up and race each other.
 */
function enhanceWithAuthOnReconnect<N>(ws: ApiWebSocket<N>) {
  /** Bumped on every `open`; lets a running loop detect it was superseded. */
  let authGeneration = 0;

  function authenticateOnReconnect() {
    const generation = ++authGeneration;
    const isSuperseded = () => generation !== authGeneration;

    if (getAuthToken().length === 0) return;

    const signOut = () => {
      if (isL2OnlySessionEnabled) {
        trackUserSignOut('ws_keep_alive_unauthorized', 'token-invalid');
        userUnauthorizedTopic.publish({});
        return;
      }
      trackUserSignOut('ws_keep_alive_unauthorized');
      userSignOutTopic.publish({});
    };

    (async function attemptToReauthenticate() {
      const maxAttempts = 5;
      for (let attemptsCount = 1; ; attemptsCount++) {
        try {
          const wsAuthResp = await ws.request({
            method: 'auth',
            params: { bearer: getAuthToken() },
            timeout: 1000 * 2 * attemptsCount,
          });

          // A newer `open` has started its own loop and owns the outcome.
          if (isSuperseded()) return;

          if (wsAuthResp.error != null) {
            logEvent(
              `WS Server responded with error to auth request` +
                ` on reconnect: ${wsAuthResp.error.message}`,
            );
            signOut();
          }
          return;
        } catch (_error) {
          const error = _error as Error;
          const isTimeout = isTimeoutError(error);
          if (isTimeout && isSuperseded()) return;
          if (!isTimeout || attemptsCount >= maxAttempts) throw error;

          const message = `Timed out waiting for WS auth response on reconnect ${attemptsCount}/${maxAttempts} times. Retrying…`;
          logEvent(message, { error });
        }
      }
    })().catch((err) => {
      const message = 'Failed to authenticate WS connection on reconnect';
      logException(new Error(message, { cause: err }));
    });
  }

  ws.addEventListener('open', authenticateOnReconnect);
}

/**
 * Creates a generator of sequential ids that wraps back to 0.
 *
 * Ids run 0, 1, …, `maxExclusive - 1`, 0, 1, …, so no id is repeated within
 * a window of `maxExclusive` calls (in particular, never twice in a row at the
 * wrap point). The default bound keeps ids inside a signed 32-bit integer —
 * NOTE(review): presumably the server's id type; confirm.
 *
 * @param maxExclusive - Exclusive upper bound of the id range; defaults to 2^31 - 1.
 * @returns A function that returns the next id on each call.
 */
export function createCounter(maxExclusive = 2147483647) {
  let next = 0;

  return () => {
    if (next >= maxExclusive) {
      next = 0;
    }
    return next++;
  };
}

/**
 * Estimates how long a WS request spent outside Cloud API processing: the
 * client-observed round trip minus the server-side processing window.
 *
 * Computed as `(receivedAt - enqueuedAt) - (usOut - usIn)`. Each parenthesised
 * difference uses two timestamps from the same clock, so any offset between
 * the client and server clocks cancels out. The result covers network transit
 * both ways plus any queueing before `usIn` or after `usOut`. (The previous
 * `usIn - enqueuedAt - receivedAt + usOut` equals twice the clock offset, not
 * a wait time.)
 *
 * @param enqueuedAt - Client time (ms) when the request was built.
 * @param usIn - Server `usIn` timestamp (µs).
 * @param receivedAt - Client time (ms) when the response arrived.
 * @param usOut - Server `usOut` timestamp (µs).
 * @returns Milliseconds, or the string `'undefined'` if any input is missing.
 */
function calcApiWaitTimeMs(
  enqueuedAt: UnixTimeMs | undefined,
  usIn: UnixTimeMicroseconds | undefined,
  receivedAt: UnixTimeMs,
  usOut: UnixTimeMicroseconds | undefined,
): number | 'undefined' {
  if (enqueuedAt == null || usIn == null || usOut == null) return 'undefined';

  const clientRoundTripMs = receivedAt - enqueuedAt;
  const serverProcessingMs = usToMs(usOut) - usToMs(usIn);
  return clientRoundTripMs - serverProcessingMs;
}
