ChatGPT是如何做到一個字一個字輸出的?
ChatGPT回答問題時,是一個字一個字彈出的,給人一種在認真思考的感覺。
這段對話的基本內容是:
問:大語言生成式模型是如何工作的
答:大型語言生成模型(如GPT-3)是基于深度學習技術的神經(jīng)網(wǎng)絡模型。這些模型通常使用一種稱為"Transformer"的架構,該架構在自然語言處理領域取得了巨大的成功。下面是大型語言生成模型的工作原理概述:
... 省略一段 ...
生成文本:一旦模型經(jīng)過訓練,就可以使用它來生成文本。給定一個初始文本或提示,模型會根據(jù)已經(jīng)學到的概率分布生成下一個單詞,然后使用生成的單詞作為輸入,并繼續(xù)生成下一個單詞,以此類推,直到達到所需的文本長度或生成終止符號。
... 省略一段 ...
大語言模型,Large Language Model,簡稱LLM。
從模型的視角來看,LLM每進行一次推理生成一個token,直到達到文本長度限制或生成終止符。
從服務端的視角來看,生成的token需要通過HTTPS協(xié)議逐個返回到瀏覽器端。
Client-Server 模式下,常規(guī)的交互方式是client端發(fā)送一次請求,接收一次響應。顯然,這無法滿足ChatGPT回復問題的場景。
其次,我們可能想到websocket,它依賴HTTP實現(xiàn)握手,升級成WebSocket。不過WebSocket需要client和server都持續(xù)占用一個socket,server側成本比較高。
ChatGPT使用的是一種折衷方案: server-sent event(簡稱SSE). 我們從OpenAI的 API 文檔可以發(fā)現(xiàn)這一點:
SSE 模式下,client只需要向server發(fā)送一次請求,server就能持續(xù)輸出,直到需要結束。整個交互過程如下圖所示:
SSE仍然使用HTTP作為應用層傳輸協(xié)議,充分利用HTTP的長連接能力,實現(xiàn)服務端推送能力。
從代碼層面來看,SSE模式與單次HTTP請求不同的點有:
- client端需要開啟 keep-alive,保證連接不會超時。
- HTTP響應的Header包含 Content-Type=text/event-stream,Cache-Cnotallow=no-cache 等。
- HTTP響應的body一般是 "data: ..." 這樣的結構。
- HTTP響應里可能有一些空數(shù)據(jù),以避免連接超時。
以 ChatGPT API 為例,在發(fā)送請求時,將stream參數(shù)設置為true就啟用了SSE特性,但在讀取數(shù)據(jù)的SDK里需要稍加注意。
在常規(guī)模式下,拿到 http.Response 后,用 ioutil.ReadAll 將數(shù)據(jù)讀出來即可,代碼如下:
func main() {
payload := strings.NewReader(`{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "大語言生成式模型是如何工作的"}],
"max_tokens": 1024,
"temperature": 1,
"top_p": 1,
"n": 1,
"stream": false
}`)
client := &http.Client{}
req, _ := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", payload)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer <OpenAI-Token>")
resp, err := client.Do(req)
if err != nil {
fmt.Println(err)
return
}
defer resp.Body.Close()
body, _ := ioutil.ReadAll(resp.Body)
fmt.Println(string(body))
}
執(zhí)行大概耗費20s+,得到一個完整的結果:
{
"id": "chatcmpl-7KklTf9mag5tyBXLEqM3PWQn4jlfD",
"object": "chat.completion",
"created": 1685180679,
"model": "gpt-3.5-turbo-0301",
"usage": {
"prompt_tokens": 21,
"completion_tokens": 358,
"total_tokens": 379
},
"choices": [
{
"message": {
"role": "assistant",
"content": "大語言生成式模型通常采用神經(jīng)網(wǎng)絡來實現(xiàn),具體工作流程如下:\n\n1. 數(shù)據(jù)預處理:將語料庫中的文本數(shù)據(jù)進行預處理,包括分詞、刪除停用詞(如“的”、“了”等常用詞匯)、去重等操作,以減少冗余信息。\n\n2. 模型訓練:采用遞歸神經(jīng)網(wǎng)絡(RNN)、長短期記憶網(wǎng)絡(LSTM)或變種的Transformers等模型進行訓練,這些模型都具有一定的記憶能力,可以學習到語言的一定規(guī)律,并預測下一個可能出現(xiàn)的詞語。\n\n3. 模型應用:當模型完成訓練后,可以將其應用于實際的生成任務中。模型接收一個輸入文本串,并預測下一個可能出現(xiàn)的詞語,直到達到一定長度或遇到結束符號為止。\n\n4. 根據(jù)生成結果對模型進行調優(yōu):生成的結果需要進行評估,如計算生成文本與語料庫文本的相似度、流暢度等指標,以此來調優(yōu)模型,提高其生成質量。\n\n總體而言,大語言生成式模型通過對語言的規(guī)律學習,從而生成高質量的文本。"
},
"finish_reason": "stop",
"index": 0
}
]
}
如果我們將 stream 設置為 true,不做任何修改,請求總消耗28s+,體現(xiàn)為很多條 stream 消息:
上面這張圖是一張Postman調用 chatgpt api的圖,走的就是 ioutil.ReadAll 的模式。為了實現(xiàn)stream讀取,我們可以分段讀取 http.Response.Body。下面是這種方式可行的原因:
- http.Response.Body 的類型是 io.ReaderCloser,底層依賴一個HTTP連接,支持stream讀。
- SSE 返回的數(shù)據(jù)通過換行符\n進行分割
所以修正的方法是通過bufio.NewReader(resp.Body)包裝起來,并在一個for-loop里讀取, 代碼如下:
// stream event 結構體定義
type ChatCompletionRspChoiceItem struct {
Delta map[string]string `json:"delta,omitempty"` // 只有 content 字段
Index int `json:"index,omitempty"`
Logprobs *int `json:"logprobs,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
type ChatCompletionRsp struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"` // unix second
Model string `json:"model"`
Choices []ChatCompletionRspChoiceItem `json:"choices"`
}
func main() {
payload := strings.NewReader(`{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "大語言生成式模型是如何工作的"}],
"max_tokens": 1024,
"temperature": 1,
"top_p": 1,
"n": 1,
"stream": true
}`)
client := &http.Client{}
req, _ := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", payload)
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+apiKey)
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
resp, err := client.Do(req)
if err != nil {
fmt.Println(err)
return
}
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
// 忽略 EOF 錯誤
break
} else {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
fmt.Printf("[PostStream] fails to read response body, timeout\n")
} else {
fmt.Printf("[PostStream] fails to read response body, err=%s\n", err)
}
}
break
}
line = bytes.TrimSuffix(line, []byte{'\n'})
line = bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(line, []byte("[DONE]")) {
break
} else if len(line) > 0 {
var chatCompletionRsp ChatCompletionRsp
if err := json.Unmarshal(line, &chatCompletionRsp); err == nil {
fmt.Printf(chatCompletionRsp.Choices[0].Delta["content"])
} else {
fmt.Printf("\ninvalid line=%s\n", line)
}
}
}
fmt.Println("the end")
}
看完client端,我們再看server端?,F(xiàn)在我們嘗試mock chatgpt server逐字返回一段文字。這里涉及到兩個點:
- Response Header 需要設置 Connection 為 keep-alive 和 Content-Type 為 text/event-stream。
- 寫入 respnose 以后,需要flush到client端。
代碼如下:
func streamHandler(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
var chatCompletionRsp ChatCompletionRsp
runes := []rune(`大語言生成式模型通常使用深度學習技術,例如循環(huán)神經(jīng)網(wǎng)絡(RNN)或變壓器(Transformer)來建模語言的概率分布。這些模型接收前面的詞匯序列,并利用其內部神經(jīng)網(wǎng)絡結構預測下一個詞匯的概率分布。然后,模型將概率最高的詞匯作為生成的下一個詞匯,并遞歸地生成一個詞匯序列,直到到達最大長度或遇到一個終止符號。
在訓練過程中,模型通過最大化生成的文本樣本的概率分布來學習有效的參數(shù)。為了避免模型產生過于平凡的、重復的、無意義的語言,我們通常會引入一些技巧,如dropout、序列擾動等。
大語言生成模型的重要應用包括文本生成、問答系統(tǒng)、機器翻譯、對話建模、摘要生成、文本分類等。`)
for _, r := range runes {
chatCompletionRsp.Choices = []ChatCompletionRspChoiceItem{
{Delta: map[string]string{"content": string(r)}},
}
bs, _ := json.Marshal(chatCompletionRsp)
line := fmt.Sprintf("data: %s\n", bs)
fmt.Fprintf(w, line)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
time.Sleep(time.Millisecond * 100)
}
fmt.Fprintf(w, "data: [DONE]\n")
}
func main() {
http.HandleFunc("/stream", streamHandler)
http.ListenAndServe(":8088", nil)
}
在真實場景中,要返回的數(shù)據(jù)來源于另一個服務或函數(shù)調用,如果這個服務或函數(shù)調用返回時間不穩(wěn)定,可能導致client端長時間收不到消息,所以一般的處理方式是:
- 對第三方的調用放到一個 goroutine 中。
- 通過 time.Tick 創(chuàng)建一個定時器,向client端發(fā)送空消息。
- 創(chuàng)建一個timeout channel,避免響應時間太久。
為了能夠從不同的channel讀取數(shù)據(jù),select 是一個不錯的關鍵字,比如這段演示代碼:
// 聲明一個 event channel
// 聲明一個 time.Tick channel
// 聲明一個 timeout channel
select {
case ev := <-events:
// send data event
case <- timeTick:
// send empty event
case <-timeout:
fmt.Fprintf(w, "[Done]\n\n")
}
小結一下
大語言模型生成響應整個結果的過程是比較漫長的,但逐token生成的響應比較快,ChatGPT將這一特性與SSE技術充分結合,一個字一個字地彈出回復,在用戶體驗上實現(xiàn)了質的提升。
縱觀生成式模型,不管是LLAMA/小羊駝 (不能商用),還是Stable Diffusion/Midjourney。在提供線上服務時,均可利用SSE技術節(jié)省提升用戶體驗,節(jié)省服務器資源。