Commit 8b8c8c98 authored by brent's avatar brent

add async

parent 858dc7dc
...@@ -23,6 +23,11 @@ meta-llama-3-8b-instruct: ...@@ -23,6 +23,11 @@ meta-llama-3-8b-instruct:
url: "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions" url: "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions"
stream: false stream: false
meta-llama-3-70b-instruct:
version: ""
url: "https://api.replicate.com/v1/models/meta/meta-llama-3-70b-instruct/predictions"
stream: false
stable-diffusion-3: stable-diffusion-3:
version: "" version: ""
url: "https://api.replicate.com/v1/models/stability-ai/stable-diffusion-3/predictions" url: "https://api.replicate.com/v1/models/stability-ai/stable-diffusion-3/predictions"
...@@ -42,3 +47,18 @@ stable-diffusion: ...@@ -42,3 +47,18 @@ stable-diffusion:
version: "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4" version: "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"
url: "https://api.replicate.com/v1/predictions" url: "https://api.replicate.com/v1/predictions"
stream: false stream: false
qwen1.5-72b:
version: "f919d3c43a8758de744cf2908426dd744154120f0a22e457a3fa647acdfe33be"
url: "https://api.replicate.com/v1/predictions"
stream: false
qwen1.5-7b:
version: "f85bec5b21ba0860e0f200be6ef5af9d5a65b974b9f99e36eb036d21eab884de"
url: "https://api.replicate.com/v1/predictions"
stream: false
llava-13b:
version: "b5f6212d032508382d61ff00469ddda3e32fd8a0e75dc39d8a4191bb742157fb"
url: "https://api.replicate.com/v1/predictions"
stream: false
\ No newline at end of file
...@@ -44,6 +44,11 @@ func (server *TaskController) Prediction() { ...@@ -44,6 +44,11 @@ func (server *TaskController) Prediction() {
return return
} }
asyncString := server.Ctx.Input.Header("Async")
async := false
if asyncString == "true" {
async = true
}
task.CreatedTime = time.Now().UTC() task.CreatedTime = time.Now().UTC()
id, err := mongo.Insert(&task) id, err := mongo.Insert(&task)
if err != nil { if err != nil {
...@@ -54,9 +59,9 @@ func (server *TaskController) Prediction() { ...@@ -54,9 +59,9 @@ func (server *TaskController) Prediction() {
whois, _ := beego.AppConfig.String("whoisApi") whois, _ := beego.AppConfig.String("whoisApi")
var result *models.TaskResponse var result *models.TaskResponse
if whois == "aonet" { if whois == "aonet" {
result, err = sendTask(&task) result, err = sendTask(&task, async)
} else if whois == "replicate" { } else if whois == "replicate" {
result, err = sendReplicate(&task) result, err = sendReplicate(&task, async)
} }
if err != nil { if err != nil {
server.respond(http.StatusOK, err.Error()) server.respond(http.StatusOK, err.Error())
...@@ -102,7 +107,7 @@ func copyImages(images []string) []string { ...@@ -102,7 +107,7 @@ func copyImages(images []string) []string {
return images return images
} }
func sendTask(task *models.Task) (*models.TaskResponse, error) { func sendTask(task *models.Task, async bool) (*models.TaskResponse, error) {
host, _ := beego.AppConfig.String("taskUrl") host, _ := beego.AppConfig.String("taskUrl")
url := host + task.ApiPath url := host + task.ApiPath
payload := new(bytes.Buffer) payload := new(bytes.Buffer)
...@@ -183,7 +188,7 @@ func readYAML(filename string, out interface{}) error { ...@@ -183,7 +188,7 @@ func readYAML(filename string, out interface{}) error {
return nil return nil
} }
func sendReplicate(task *models.Task) (*models.TaskResponse, error) { func sendReplicate(task *models.Task, async bool) (*models.TaskResponse, error) {
host, _ := beego.AppConfig.String("replicateUrl") host, _ := beego.AppConfig.String("replicateUrl")
url := host url := host
payload := new(bytes.Buffer) payload := new(bytes.Buffer)
...@@ -214,11 +219,12 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) { ...@@ -214,11 +219,12 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
//version, _ := beego.AppConfig.String(lastElement) //version, _ := beego.AppConfig.String(lastElement)
//model, _ := config.GetSection(lastElement) //model, _ := config.GetSection(lastElement)
model := supportModels[lastElement] model := supportModels[lastElement]
taskReturn := &models.TaskReturn{} taskReturn := &models.TaskReturn{
Async: async,
}
taskResponse := &models.TaskResponse{ taskResponse := &models.TaskResponse{
Task: taskReturn, Task: taskReturn,
} }
if model == nil { if model == nil {
task.Status = 3 task.Status = 3
taskReturn.ExecError = "It`s not open yet." taskReturn.ExecError = "It`s not open yet."
...@@ -274,6 +280,115 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) { ...@@ -274,6 +280,115 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
return nil, err return nil, err
} }
if response.Urls.Get != "" || (model.Stream && response.Urls.Stream != "") { if response.Urls.Get != "" || (model.Stream && response.Urls.Stream != "") {
//replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout")
//
//timeout := time.After(time.Duration(replicateTimeout) * time.Minute)
//
//for {
// select {
// case <-timeout:
// logs.Info("Operation timed out")
// task.Status = 4 // 4表示超时状态
// taskReturn.TaskError = "Operation timed out"
// taskResponse.Task = taskReturn
// task.Error = taskResponse.Task
// _, err := mongo.Update(task)
// if err != nil {
// logs.Info("Update Task Error:", err)
// }
// return taskResponse, nil
// default:
// temp, err := getReplicate(response.Urls.Get)
// if err != nil {
// logs.Info("getReplicate Task Error:", err)
// }
// logs.Info("getReplicate Task temp:", temp)
// if temp != nil && temp.Status == "succeeded" {
// //todo 返回
// var output []string
// if str, ok := temp.Output.(string); ok {
// fmt.Println("i 是字符串类型,值为:", str)
// output = append(output, str)
// }
// if slice, ok := temp.Output.([]string); ok {
// fmt.Println("i 是字符串数组类型,值为:", slice)
// output = slice
// }
// if slice, ok := temp.Output.([]interface{}); ok {
// fmt.Println("i 是interface{}数组类型,值为:", slice)
// for _, value := range slice {
// if str, ok := value.(string); ok {
// fmt.Println("i 是字符串类型,值为:", str)
// output = append(output, str)
// }
// }
// }
// if slice, ok := temp.Output.(map[string]string); ok {
// fmt.Println("i 是map类型,值为:", slice)
// for _, value := range slice {
// output = append(output, value)
// }
// }
// task.Status = 2
// task.Output = output
// mongo.Update(task)
// taskResponse.Output = task.Output
// taskResponse.Task.IsSuccess = true
// taskResponse.Task.ExecCode = 200
//
// return taskResponse, nil
// } else if temp != nil && temp.Error != nil {
// task.Status = 3
// if response.Error != nil {
// data, err := json.Marshal(response.Error)
// if err != nil {
// logs.Info("sendTask response.Task Unmarshal response:", err)
// }
// if data != nil && len(data) > 0 {
// taskReturn.ExecError = string(data)
// }
// taskResponse.Task = taskReturn
// task.Error = taskResponse.Task
// }
// _, err = mongo.Update(task)
// if err != nil {
// logs.Info("Update Task Error:", err)
// }
// return taskResponse, nil
// }
// time.Sleep(time.Second)
// }
//}
if async {
go doGetReplicate(response.Urls.Get, task, taskResponse, taskReturn)
return taskResponse, nil
}
doGetReplicate(response.Urls.Get, task, taskResponse, taskReturn)
return taskResponse, nil
} else {
task.Status = 3
if response.Error != nil {
data, err := json.Marshal(response.Error)
if err != nil {
logs.Info("sendTask response.Task Unmarshal response:", err)
}
if data != nil && len(data) > 0 {
taskReturn.ExecError = string(data)
}
taskResponse.Task = taskReturn
task.Error = taskResponse.Task
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
}
return taskResponse, nil
}
func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResponse, taskReturn *models.TaskReturn) {
replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout") replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout")
timeout := time.After(time.Duration(replicateTimeout) * time.Minute) timeout := time.After(time.Duration(replicateTimeout) * time.Minute)
...@@ -290,9 +405,9 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) { ...@@ -290,9 +405,9 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
if err != nil { if err != nil {
logs.Info("Update Task Error:", err) logs.Info("Update Task Error:", err)
} }
return taskResponse, nil return
default: default:
temp, err := getReplicate(response.Urls.Get) temp, err := getReplicate(url)
if err != nil { if err != nil {
logs.Info("getReplicate Task Error:", err) logs.Info("getReplicate Task Error:", err)
} }
...@@ -329,12 +444,11 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) { ...@@ -329,12 +444,11 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
taskResponse.Output = task.Output taskResponse.Output = task.Output
taskResponse.Task.IsSuccess = true taskResponse.Task.IsSuccess = true
taskResponse.Task.ExecCode = 200 taskResponse.Task.ExecCode = 200
return
return taskResponse, nil
} else if temp != nil && temp.Error != nil { } else if temp != nil && temp.Error != nil {
task.Status = 3 task.Status = 3
if response.Error != nil { if temp.Error != nil {
data, err := json.Marshal(response.Error) data, err := json.Marshal(temp.Error)
if err != nil { if err != nil {
logs.Info("sendTask response.Task Unmarshal response:", err) logs.Info("sendTask response.Task Unmarshal response:", err)
} }
...@@ -348,31 +462,12 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) { ...@@ -348,31 +462,12 @@ func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
if err != nil { if err != nil {
logs.Info("Update Task Error:", err) logs.Info("Update Task Error:", err)
} }
return taskResponse, nil return
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
} }
} else {
task.Status = 3
if response.Error != nil {
data, err := json.Marshal(response.Error)
if err != nil {
logs.Info("sendTask response.Task Unmarshal response:", err)
}
if data != nil && len(data) > 0 {
taskReturn.ExecError = string(data)
}
taskResponse.Task = taskReturn
task.Error = taskResponse.Task
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
}
return taskResponse, nil
} }
func getReplicate(url string) (*models.ReplicateResponse, error) { func getReplicate(url string) (*models.ReplicateResponse, error) {
...@@ -431,17 +526,25 @@ func setError(task *models.Task, error string) { ...@@ -431,17 +526,25 @@ func setError(task *models.Task, error string) {
} }
} }
// Result @router /result/:task_id [get] // Result @router /result/:excute_id [get]
func (server *TaskController) Result() { func (server *TaskController) Result() {
taskId := server.GetString("task_id") taskId := server.GetString("excute_id")
data := struct { if taskId == "" {
Result string `json:"result"` server.respond(models.NoRequestBody, "excute_id is null")
TaskId string `json:"task_id"` return
}{ }
Result: "success", task := models.Task{
TaskId: taskId, ExcuteId: taskId,
}
filter := map[string]string{
"excute_id": taskId,
}
err := mongo.Read(&task, filter)
if err != nil {
logs.Info("List Error:", err)
server.respond(models.BusinessFailed, err.Error())
} }
server.respond(http.StatusOK, "", data) server.respond(http.StatusOK, "", task)
} }
// List @router /list/ [post] // List @router /list/ [post]
......
{"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1719995732323745955} {"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1720592962344996898}
\ No newline at end of file \ No newline at end of file
...@@ -2,6 +2,7 @@ package main ...@@ -2,6 +2,7 @@ package main
import ( import (
_ "aon_app_server/routers" _ "aon_app_server/routers"
"github.com/beego/beego/v2/core/logs"
beego "github.com/beego/beego/v2/server/web" beego "github.com/beego/beego/v2/server/web"
"github.com/beego/beego/v2/server/web/filter/cors" "github.com/beego/beego/v2/server/web/filter/cors"
) )
...@@ -10,10 +11,12 @@ func init() { ...@@ -10,10 +11,12 @@ func init() {
beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
AllowAllOrigins: true, AllowAllOrigins: true,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Authorization", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type", "X-Xsrf-Token"}, AllowHeaders: []string{"Origin", "Authorization", "Async", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type", "X-Xsrf-Token"},
ExposeHeaders: []string{"Content-Length", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type", "X-Xsrf-Token", "Authorization"}, ExposeHeaders: []string{"Content-Length", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type", "X-Xsrf-Token", "Authorization", "Async"},
AllowCredentials: true, AllowCredentials: true,
})) }))
timeout := beego.BConfig.Listen.ServerTimeOut
logs.Debug("timeout = ", timeout)
} }
func main() { func main() {
......
...@@ -60,6 +60,7 @@ type TaskReturn struct { ...@@ -60,6 +60,7 @@ type TaskReturn struct {
TaskError string `json:"task_error"` TaskError string `json:"task_error"`
ExecCode int `json:"exec_code"` ExecCode int `json:"exec_code"`
ExecError string `json:"exec_error"` ExecError string `json:"exec_error"`
Async bool `json:"async"`
ApiError struct { ApiError struct {
RequestId string `json:"request_id"` RequestId string `json:"request_id"`
Message string `json:"message"` Message string `json:"message"`
......
package models
//type Position int
//
//const (
// TOPLEFT Position = iota + 1
// TOPRIGHT
// TOPMIDDLE
// LEFTMIDDLE
// LEFTBOTTOM
// BOTTOMMIDDLE
// RIGHTBOTTOM
// RIGHTMIDDLE
// CENTER
//)
type Position struct {
X int `json:"x,omitempty" bson:"x"`
Y int `json:"y,omitempty" bson:"y"`
}
type Template struct {
Id interface{} `json:"id" bson:"_id,omitempty"`
Logo string `json:"logo,omitempty" bson:"logo"` // 海报图片上展示的logo,非 应用logo
Image string `json:"image,omitempty" bson:"image"` //模版展示的图片
Brand string `json:"brand,omitempty" bson:"brand"` // 海报图片上展示的文字,非 应用标题
Watermark string `json:"watermark,omitempty" bson:"watermark"` //商家 的 水印
WatermarkPosition Position `json:"watermark_position,omitempty" bson:"watermark_position"` //商家 的 水印位置
Platform string `json:"platform,omitempty" bson:"platform"` //平台 的 水印
PlatformPosition Position `json:"platform_position,omitempty" bson:"platform_position"`
Prompt string `json:"prompt" bson:"prompt"`
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment