// chat.ts — chat and chat-session API client
import { service as http, API_BASE_URL } from "@/utils/request";
import type { AxiosPromise } from "axios";

/**
 * Reads the persisted auth token from localStorage (key "auth_token").
 *
 * @returns the stored token, or an empty string when none is stored.
 */
function getToken(): string {
  const stored = localStorage.getItem("auth_token");
  return stored === null ? "" : stored;
}

/**
 * Builds the `Authorization` header for requests issued with `fetch` rather
 * than the axios `http` client.
 *
 * @returns `{ Authorization: "Bearer <token>" }` when a token is stored,
 *   otherwise an empty object (so the result can always be spread into headers).
 */
function authHeaders(): { Authorization?: string } {
  const token = getToken();
  if (!token) {
    return {};
  }
  return { Authorization: `Bearer ${token}` };
}

/** ---------------- Chat API ---------------- **/

/** Request body for `askQuestion` (POST /chat/ask). */
interface AskQuestionPayload {
  /** The question text sent to the agent (required). */
  questionContent: string;
  /** Chat session to ask in; presumably omitted to start a new one — TODO confirm with backend. */
  chatId?: number;
  /** Attached file identifiers; element type is not constrained here. */
  fileIds?: any[];
  /** Additional context items; element shape is not constrained here. */
  context?: any[];
  /** Model provider identifier, e.g. "baichuan"; accepted values are defined by the backend. */
  provider?: string;
}

/** Response envelope returned by `askQuestion` (POST /chat/ask). */
interface AskQuestionResponse {
  /** Status code of the envelope; presumably a business code — TODO confirm the success value. */
  code: number;
  /** Human-readable status message accompanying `code`. */
  message: string;
  /** Identifiers assigned to the newly submitted question. */
  data: {
    /** Question id in this backend. */
    questionId: string;
    /** Question id in an external system; meaning not visible here — TODO confirm. */
    externalQuestionId: string;
  };
}

/**
 * Submits a question to the agent.
 * POST /api/chat/ask
 *
 * The request path is relative to the shared `http` client's base URL, which
 * presumably supplies the `/api` prefix.
 *
 * @param payload question text plus optional chat id, files, context and provider.
 * @returns the axios promise for the response envelope.
 */
export function askQuestion(
  payload: AskQuestionPayload,
): AxiosPromise<AskQuestionResponse> {
  const endpoint = "/chat/ask";
  return http.post(endpoint, payload);
}

/** Options accepted by `streamAnswer`. */
interface StreamAnswerOptions {
  /** Chat session whose answer stream is opened (URL-encoded into the path). */
  chatId: number | string;
  /** Question to stream the answer for; sent as the `questionId` query parameter. */
  questionId?: string;
  /** Accepted for API compatibility; `streamAnswer` does not currently use it. */
  externalQuestionId?: string;
  /** Optional model provider; only "baichuan" is forwarded as a query parameter. */
  provider?: string;
  /** Called once the HTTP response is accepted, before any message. */
  onOpen?: () => void;
  /** Called with each SSE `data:` payload — a decoded JSON value or the raw text. */
  onMessage?: (chunk: string | object) => void;
  /** Called when the stream fails for a reason other than cancellation. */
  onError?: (err: any) => void;
  /** Called when the stream finishes or is cancelled. */
  onEnd?: () => void;
  /** External abort signal; aborting it cancels the request. */
  signal?: AbortSignal;
}

/**
 * Converts the text of one SSE `data:` field into the value handed to `onMessage`.
 *
 * JSON objects, arrays and strings are decoded. Anything else — plain text, and
 * bare JSON scalars such as `2024`, `true` or `null`, which in a text stream are
 * almost certainly literal tokens — is returned verbatim, so callers always get
 * the declared `string | object`.
 */
function parseSseData(data: string): string | object {
  try {
    const parsed: unknown = JSON.parse(data);
    if (typeof parsed === "string") return parsed;
    if (typeof parsed === "object" && parsed !== null) return parsed;
  } catch {
    // Not JSON: fall through and deliver the chunk as plain text.
  }
  return data;
}

/**
 * Returns true when an error from the fetch/read loop represents a cancellation
 * rather than a genuine failure.
 *
 * Both `name` and `message` are inspected: an interrupted body read can surface
 * as `TypeError: BodyStreamBuffer was aborted`, whose name alone ("TypeError")
 * says nothing about the abort. Non-Error throwables (including null/undefined)
 * are tolerated.
 */
function isAbortLikeError(err: unknown): boolean {
  if (err == null) return false;
  const { name, message } = err as { name?: unknown; message?: unknown };
  if (name === "AbortError") return true;
  const text = `${String(name ?? "")} ${String(message ?? "")}`.toLowerCase();
  return text.includes("aborted") || text.includes("bodystreambuffer");
}

/**
 * Streams the answer for a chat question over Server-Sent Events.
 * GET /api/chat/{chatId}/stream?questionId=&provider=
 *
 * The body is read incrementally from fetch's ReadableStream; the auth token is
 * attached via `authHeaders()`. Each `data:` line is delivered to `onMessage` as
 * soon as it arrives (lines are not grouped into multi-line events). A `[DONE]`
 * payload ends the stream: reading stops, the body is cancelled so the
 * connection is released, and nothing after the sentinel is delivered.
 *
 * Exactly one terminal callback fires per stream: `onEnd` after a normal end,
 * `[DONE]`, `close()`, or an abort of `signal`; `onError` for anything else
 * (network failure, non-2xx status, missing body, or an exception thrown by
 * `onOpen`/`onMessage`). Exceptions thrown by `onEnd`/`onError` are not caught.
 *
 * @returns a handle whose `close()` cancels the request.
 */
export function streamAnswer({
  chatId,
  questionId,
  externalQuestionId: _externalQuestionId,
  provider,
  onMessage,
  onOpen,
  onError,
  onEnd,
  signal,
}: StreamAnswerOptions): { close: () => void } {
  const params = new URLSearchParams();
  if (questionId) params.append("questionId", questionId);
  // Only "baichuan" is forwarded; any other provider uses the server default.
  if (provider === "baichuan") params.append("provider", "baichuan");
  const query = params.toString();
  const url =
    `${API_BASE_URL}/chat/${encodeURIComponent(chatId)}/stream` +
    (query ? `?${query}` : "");

  const controller = new AbortController();
  // Set by close() or by the [DONE] sentinel; an error raised after that point is
  // a consequence of our own cancellation and is reported via onEnd, not onError.
  let closed = false;

  const forwardAbort = () => controller.abort();
  if (signal) {
    if (signal.aborted) controller.abort();
    else signal.addEventListener("abort", forwardAbort, { once: true });
  }

  // Handles one complete line; returns true when it is the [DONE] sentinel.
  const handleLine = (rawLine: string): boolean => {
    // Drop only the CR of a CRLF terminator: trimming would also eat meaningful
    // leading/trailing spaces inside plain-text chunks.
    const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine;
    // Comments (":"), blank separators and non-data fields are ignored.
    if (!line.startsWith("data:")) return false;
    const data = line.replace(/^data:\s?/, "");
    if (data.trim() === "[DONE]") return true;
    if (data) onMessage?.(parseSseData(data));
    return false;
  };

  void (async () => {
    let reader: ReadableStreamDefaultReader | undefined;
    const decoder = new TextDecoder("utf-8");
    let buf = "";
    let failed = false;
    let failure: unknown;

    try {
      const res = await fetch(url, {
        method: "GET",
        headers: {
          Accept: "text/event-stream",
          "X-Accel-Buffering": "no",
          ...authHeaders(),
        },
        credentials: "include",
        signal: controller.signal,
      });

      if (!res.ok || !res.body)
        throw new Error(`SSE failed: ${res.status} ${res.statusText}`);

      onOpen?.();
      reader = res.body.getReader();

      let sawDone = false;
      while (!sawDone) {
        const { done, value } = await reader.read();
        // decode() with no argument flushes any partial multi-byte sequence at EOF.
        buf += done ? decoder.decode() : decoder.decode(value, { stream: true });
        for (
          let nl = buf.indexOf("\n");
          nl !== -1 && !sawDone;
          nl = buf.indexOf("\n")
        ) {
          const line = buf.slice(0, nl);
          buf = buf.slice(nl + 1);
          sawDone = handleLine(line);
        }
        if (done) break;
      }

      if (sawDone) {
        closed = true;
        // The server may keep the connection open after [DONE]; stop reading and release it.
        await reader.cancel().catch(() => {});
      } else if (buf) {
        // Final line that arrived without a trailing newline.
        handleLine(buf);
      }
    } catch (err) {
      if (!closed && !isAbortLikeError(err)) {
        failed = true;
        failure = err;
      }
    } finally {
      // Avoid accumulating listeners on a long-lived external signal.
      signal?.removeEventListener("abort", forwardAbort);
      try {
        reader?.releaseLock();
      } catch {
        // Lock already released or stream already torn down; nothing to do.
      }
    }

    if (failed) onError?.(failure);
    else onEnd?.();
  })();

  return {
    close: () => {
      closed = true;
      controller.abort();
    },
  };
}

/** ---------------- Chat Session API ---------------- **/

/** A chat session as returned by the session endpoints. */
interface ChatSession {
  /** Numeric session id. */
  id: number;
  /** Display title of the session. */
  title: string;
  /** Presumably a soft-delete marker set by the backend — TODO confirm. */
  deleted: boolean;
  /** Creation timestamp as a string; format is not visible here. */
  createdAt: string;
}

/** One question/answer pair within a chat session. */
interface ChatMessage {
  /** Id of the question this pair belongs to. */
  questionId: string;
  /** The question text. */
  question: string;
  /** When the question was asked (string timestamp; format not visible here). */
  questionTime: string;
  /** The answer text. */
  answer: string;
  /** When the answer was produced (string timestamp; format not visible here). */
  answerTime: string;
  /** Record creation timestamp (string; format not visible here). */
  createdAt: string;
}

/**
 * Lists chat sessions, optionally paginated and filtered by provider.
 * GET /api/agent/chats?page=&size=&provider=
 *
 * @param params.page page number to fetch.
 * @param params.size page size.
 * @param params.provider "external" (Doubao) or "baichuan" (Baichuan); omit to
 *   list sessions from every provider.
 * @returns the axios promise for the session list.
 */
export function listChats(
  params: { page?: number; size?: number; provider?: string } = {},
): AxiosPromise<ChatSession[]> {
  const requestConfig = { params };
  return http.get("/agent/chats", requestConfig);
}

/**
 * Creates a new chat session.
 * POST /api/agent/chats
 *
 * @param payload the new session's title; the object is posted as-is.
 * @returns the axios promise for the created session.
 */
export function createChat(payload: {
  title: string;
}): AxiosPromise<ChatSession> {
  const endpoint = "/agent/chats";
  return http.post(endpoint, payload);
}

/**
 * Deletes a chat session.
 * DELETE /api/agent/chats/{chatId}
 *
 * @param chatId session id; URL-encoded before being placed in the path.
 * @returns the axios promise for the delete request.
 */
export function deleteChat(chatId: number | string): AxiosPromise<any> {
  const path = `/agent/chats/${encodeURIComponent(chatId)}`;
  return http.delete(path);
}

/**
 * Lists the question/answer pairs of a chat session.
 * GET /api/agent/chats/{chatId}/messages?afterId=&limit=
 *
 * @param chatId session id; URL-encoded before being placed in the path.
 * @param params.afterId cursor forwarded as the `afterId` query parameter.
 * @param params.limit maximum number of pairs to return.
 * @returns the axios promise for the message list.
 */
export function listMessages(
  chatId: number | string,
  params: { afterId?: string; limit?: number } = {},
): AxiosPromise<ChatMessage[]> {
  const path = `/agent/chats/${encodeURIComponent(chatId)}/messages`;
  const requestConfig = { params };
  return http.get(path, requestConfig);
}

/**
 * Asks the server to stop an in-progress answer stream for a chat session.
 * POST /api/chat/{chatId}/cancel
 *
 * @param chatId session id; URL-encoded before being placed in the path.
 * @returns the axios promise for the cancel request (sent without a body).
 */
export function cancelChat(chatId: number | string): AxiosPromise<any> {
  const path = `/chat/${encodeURIComponent(chatId)}/cancel`;
  return http.post(path);
}


/**
 * Default export bundling the chat API functions of this module, for callers
 * that import the whole API as one object. Includes `cancelChat` so the object
 * offers the same operations as the named exports.
 */
const chatApi = {
  askQuestion,
  streamAnswer,
  listChats,
  createChat,
  deleteChat,
  listMessages,
  cancelChat,
};

export default chatApi;