Commit 3cf4394e authored by brent's avatar brent

modify response

parent 308db6ac
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/config"
...@@ -17,6 +18,7 @@ import ( ...@@ -17,6 +18,7 @@ import (
beego "github.com/beego/beego/v2/server/web" beego "github.com/beego/beego/v2/server/web"
"github.com/fogleman/gg" "github.com/fogleman/gg"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"image" "image"
"image/jpeg" "image/jpeg"
...@@ -71,12 +73,13 @@ func (server *TaskController) Prediction() { ...@@ -71,12 +73,13 @@ func (server *TaskController) Prediction() {
} }
task.Id = id task.Id = id
whois, _ := beego.AppConfig.String("whoisApi") whois, _ := beego.AppConfig.String("whoisApi")
var result *models.TaskResponse var result interface{}
if whois == "aonet" { if whois == "aonet" {
result, err = sendTask(&task, async) result, err = sendTask(&task, async)
} else if whois == "replicate" { } else if whois == "replicate" {
result, err = sendReplicate(&task, async) result, err = sendReplicate(&task, async)
} }
if err != nil { if err != nil {
server.respond(http.StatusOK, err.Error()) server.respond(http.StatusOK, err.Error())
return return
...@@ -94,6 +97,17 @@ func (server *TaskController) Prediction() { ...@@ -94,6 +97,17 @@ func (server *TaskController) Prediction() {
server.respond(http.StatusOK, "", response) server.respond(http.StatusOK, "", response)
return return
} }
//if whois == "replicate" {
// temp := models.TaskResponseNew{}
// if err = json.Unmarshal(body, &temp); err != nil {
// server.respond(http.StatusInternalServerError, err.Error())
// return
// }
// if temp.Error != nil {
// server.respond(http.StatusInternalServerError, "", temp)
// return
// }
//}
server.respond(http.StatusOK, "", result) server.respond(http.StatusOK, "", result)
} }
...@@ -466,7 +480,12 @@ func sendTask(task *models.Task, async bool) (*models.TaskResponse, error) { ...@@ -466,7 +480,12 @@ func sendTask(task *models.Task, async bool) (*models.TaskResponse, error) {
task.UpdatedTime = time.Now().UTC() task.UpdatedTime = time.Now().UTC()
if response.Task.IsSuccess { if response.Task.IsSuccess {
task.Status = 2 task.Status = 2
task.Output = copyImages(response.Output) var output []string
if slice, ok := response.Output.([]string); ok {
fmt.Println("i 是字符串数组类型,值为:", slice)
output = slice
}
task.Output = copyImages(output)
mongo.Update(task) mongo.Update(task)
response.Output = task.Output response.Output = task.Output
} else { } else {
...@@ -517,7 +536,7 @@ func sendReplicate(task *models.Task, async bool) (*models.TaskResponse, error) ...@@ -517,7 +536,7 @@ func sendReplicate(task *models.Task, async bool) (*models.TaskResponse, error)
// "face-swap": "bc479d7d8ecc50eb83839af0c28210db75cac9c23837e2722028df4cddfafa22", // "face-swap": "bc479d7d8ecc50eb83839af0c28210db75cac9c23837e2722028df4cddfafa22",
//} //}
parts := strings.Split(task.ApiPath, "/") parts := strings.Split(task.ApiPath, "/")
lastElement := "" lastElement := task.ApiPath
if len(parts) > 2 { if len(parts) > 2 {
lastElement = parts[len(parts)-1] lastElement = parts[len(parts)-1]
} }
...@@ -694,6 +713,108 @@ func sendReplicate(task *models.Task, async bool) (*models.TaskResponse, error) ...@@ -694,6 +713,108 @@ func sendReplicate(task *models.Task, async bool) (*models.TaskResponse, error)
return taskResponse, nil return taskResponse, nil
} }
func sendReplicateNew(task *models.Task, async bool) (*models.TaskResponseNew, error) {
host, _ := beego.AppConfig.String("replicateUrl")
url := host
payload := new(bytes.Buffer)
v := reflect.ValueOf(task.Input)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
var input interface{}
for _, key := range v.MapKeys() {
value := v.MapIndex(key).Interface()
keyStr := fmt.Sprintf("%v", key.Interface())
logs.Debug("keyStr: value\n", keyStr, value)
if keyStr == "input" {
input = value
}
}
parts := strings.Split(task.ApiPath, "/")
lastElement := task.ApiPath
if len(parts) > 2 {
lastElement = parts[len(parts)-1]
}
model := supportModels[lastElement]
taskReturn := &models.TaskReturn{
Async: async,
}
if model == nil {
task.Status = 3
taskReturn.ExecError = "It`s not open yet."
task.Error = taskReturn
_, err := mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return nil, errors.New("It`s not open yet.")
}
data := models.ReplicateRequest{
Version: model.Version,
Input: input,
Stream: model.Stream,
}
url = model.Url
json.NewEncoder(payload).Encode(data)
client := &http.Client{}
request, err := http.NewRequest("POST", url, payload)
if err != nil {
setError(task, "sendReplicate request create error:"+err.Error())
logs.Info("sendReplicate Error NewRequest request:", err)
return nil, err
}
apikey, _ := beego.AppConfig.String("replicateToken")
request.Header.Add("Authorization", "Bearer "+apikey)
logs.Info("sendReplicate client sending request:")
resp, err := client.Do(request)
if err != nil {
setError(task, "Task sending error:"+err.Error())
logs.Info("sendReplicate Error sending request:", err)
return nil, err
}
defer resp.Body.Close()
logs.Info("sendReplicate resp code", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
logs.Info("sendReplicate body", string(body))
if resp.StatusCode == 200 && body == nil {
setError(task, "Task response body null")
logs.Info("sendReplicate Body reading response:", err)
return nil, err
}
if err != nil {
setError(task, "Task read response body error:"+err.Error())
logs.Info("sendReplicate Error reading response:", err)
return nil, err
}
var response models.ReplicateResponse
if err = json.Unmarshal(body, &response); err != nil {
setError(task, "Task response Unmarshal error:"+err.Error())
logs.Info("sendReplicate Error Unmarshal response:", err)
return nil, err
}
taskResponse := &models.TaskResponseNew{}
if response.Urls.Get != "" || (model.Stream && response.Urls.Stream != "") {
if async {
go doGetReplicateNew(response.Urls.Get, task, taskResponse, taskReturn)
return taskResponse, nil
}
doGetReplicateNew(response.Urls.Get, task, taskResponse, taskReturn)
return taskResponse, nil
} else {
task.Status = 3
if response.Error != nil {
task.Error = response.Error
taskResponse.Error = response.Error
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
}
return taskResponse, nil
}
func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResponse, taskReturn *models.TaskReturn) { func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResponse, taskReturn *models.TaskReturn) {
replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout") replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout")
...@@ -758,11 +879,16 @@ func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResp ...@@ -758,11 +879,16 @@ func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResp
} }
task.Status = 2 task.Status = 2
isImage := checkFileIsImage(output) task.ReplicateOutput = temp.Output
if isImage { if len(output) > 0 {
task.Output = transferImagesToS3(output, task) isImage := checkFileIsImage(output)
if isImage {
task.Output = transferImagesToS3(output, task)
} else {
task.Output = transferImages(output)
}
} else { } else {
task.Output = transferImages(output) task.Output = temp.Output
} }
mongo.Update(task) mongo.Update(task)
...@@ -795,6 +921,101 @@ func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResp ...@@ -795,6 +921,101 @@ func doGetReplicate(url string, task *models.Task, taskResponse *models.TaskResp
} }
func doGetReplicateNew(url string, task *models.Task, taskResponse *models.TaskResponseNew, taskReturn *models.TaskReturn) {
replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout")
timeout := time.After(time.Duration(replicateTimeout) * time.Minute)
for {
select {
case <-timeout:
logs.Info("Operation timed out")
taskResponse.Error = errors.New("Operation timed out")
task.Status = 4 // 4表示超时状态
taskReturn.TaskError = "Operation timed out"
task.Error = taskReturn
_, err := mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return
default:
temp, err := getReplicate(url)
if err != nil {
logs.Info("getReplicate Task Error:", err)
}
logs.Info("getReplicate Task temp:", temp)
if temp != nil && temp.Status == "succeeded" {
//todo 返回
var output []string
if str, ok := temp.Output.(string); ok {
fmt.Println("i 是字符串类型,值为:", str)
output = append(output, str)
}
if slice, ok := temp.Output.([]string); ok {
fmt.Println("i 是字符串数组类型,值为:", slice)
output = slice
}
if slice, ok := temp.Output.([]interface{}); ok {
fmt.Println("i 是interface{}数组类型,值为:", slice)
for _, value := range slice {
if str, ok := value.(string); ok {
fmt.Println("i 是字符串类型,值为:", str)
output = append(output, str)
}
}
}
if slice, ok := temp.Output.(map[string]string); ok {
fmt.Println("i 是map类型,值为:", slice)
for _, value := range slice {
output = append(output, value)
}
}
if slice, ok := temp.Output.(map[string]interface{}); ok {
fmt.Println("i 是map类型,值为:", slice)
for _, value := range slice {
logs.Info("getReplicate value:", value)
if str, ok1 := value.(string); ok1 {
fmt.Println("i 是字符串类型,值为:", str)
output = append(output, str)
}
//output = append(output, value)
}
}
task.Status = 2
task.ReplicateOutput = temp.Output
isImage := checkFileIsImage(output)
if isImage {
task.Output = transferImagesToS3(output, task)
} else {
task.Output = transferImages(output)
}
mongo.Update(task)
taskResponse.Output = temp.Output
//taskResponse.Task.IsSuccess = true
//taskResponse.Task.ExecCode = 200
return
} else if temp != nil && temp.Error != nil {
task.Status = 3
if temp.Error != nil {
taskResponse.Error = temp.Error
task.Error = temp.Error
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return
}
time.Sleep(time.Second)
}
}
}
func getReplicate(url string) (*models.ReplicateResponse, error) { func getReplicate(url string) (*models.ReplicateResponse, error) {
//host, _ := beego.AppConfig.String("replicateUrl") //host, _ := beego.AppConfig.String("replicateUrl")
//url := host + task.ApiPath //url := host + task.ApiPath
...@@ -871,6 +1092,67 @@ func (server *TaskController) AddResult() { ...@@ -871,6 +1092,67 @@ func (server *TaskController) AddResult() {
server.respond(http.StatusOK, "") server.respond(http.StatusOK, "")
} }
func convertToProperType(value interface{}) interface{} {
switch v := value.(type) {
case primitive.M:
// 如果是 primitive.M,直接返回 map[string]interface{}
return map[string]interface{}(v)
case primitive.D:
// 如果是 primitive.D,转换为 map[string]interface{}
result := make(map[string]interface{})
for _, elem := range v {
result[elem.Key] = elem.Value
}
return result
case primitive.A:
// 如果是 primitive.A,判断其内部元素类型
if len(v) > 0 {
switch v[0].(type) {
case primitive.M:
// 如果是 []primitive.M,转换为 []map[string]interface{}
var result []map[string]interface{}
for _, elem := range v {
result = append(result, elem.(primitive.M))
}
return result
case primitive.D:
// 如果是 []primitive.D,转换为 []map[string]interface{}
var result []map[string]interface{}
for _, elem := range v {
converted := make(map[string]interface{})
for _, innerElem := range elem.(primitive.D) {
converted[innerElem.Key] = innerElem.Value
}
result = append(result, converted)
}
return result
case string:
// 如果是 []string,直接返回 []string
var result []string
for _, elem := range v {
result = append(result, elem.(string))
}
return result
}
}
case string:
// 如果是字符串,返回字符串
return v
default:
// 如果类型不匹配,返回 nil 或者一个错误提示
fmt.Println("Unsupported type")
return nil
}
return nil
}
func (server *TaskController) Result() { func (server *TaskController) Result() {
taskId := server.GetString("excute_id") taskId := server.GetString("excute_id")
if taskId == "" { if taskId == "" {
...@@ -901,6 +1183,23 @@ func (server *TaskController) Result() { ...@@ -901,6 +1183,23 @@ func (server *TaskController) Result() {
} }
task.Input = input task.Input = input
} }
if task.Output != nil {
output := convertToProperType(task.Output)
task.Output = output
//if _, ok := task.Output.(primitive.D); ok {
// raw, err := bson.Marshal(task.Output)
// if err != nil {
// server.respond(models.BusinessFailed, err.Error())
// }
// // 将 bson.Raw 解码为 User 结构体
// var output map[string]interface{}
// err = bson.Unmarshal(raw, &output)
// if err != nil {
// server.respond(models.BusinessFailed, err.Error())
// }
// task.Output = output
//}
}
if task.Error != nil { if task.Error != nil {
raw, err := bson.Marshal(task.Error) raw, err := bson.Marshal(task.Error)
...@@ -967,6 +1266,24 @@ func (server *TaskController) List() { ...@@ -967,6 +1266,24 @@ func (server *TaskController) List() {
task.Input = input task.Input = input
} }
if task.Output != nil {
output := convertToProperType(task.Output)
task.Output = output
//if _, ok := task.Output.(primitive.D); ok {
// raw, err := bson.Marshal(task.Output)
// if err != nil {
// server.respond(models.BusinessFailed, err.Error())
// }
// // 将 bson.Raw 解码为 User 结构体
// var output map[string]interface{}
// err = bson.Unmarshal(raw, &output)
// if err != nil {
// server.respond(models.BusinessFailed, err.Error())
// }
// task.Output = output
//}
}
if task.Error != nil { if task.Error != nil {
if str, ok := task.Error.(string); ok { if str, ok := task.Error.(string); ok {
fmt.Println("The string is:", str) fmt.Println("The string is:", str)
......
{"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1723107582385421558} {"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1723261526218747361}
\ No newline at end of file \ No newline at end of file
...@@ -14,19 +14,20 @@ const ( ...@@ -14,19 +14,20 @@ const (
) )
type Task struct { type Task struct {
Id interface{} `json:"id" bson:"_id,omitempty"` Id interface{} `json:"id" bson:"_id,omitempty"`
TaskId string `json:"task_id,omitempty" bson:"task_id"` TaskId string `json:"task_id,omitempty" bson:"task_id"`
Input interface{} `json:"input" bson:"input"` Input interface{} `json:"input" bson:"input"`
ApiPath string `json:"api_path" bson:"api_path"` ApiPath string `json:"api_path" bson:"api_path"`
UserId string `json:"user_id" bson:"user_id"` UserId string `json:"user_id" bson:"user_id"`
Status int `json:"status" bson:"status"` Status int `json:"status" bson:"status"`
Output []string `json:"output" bson:"output"` Output interface{} `json:"output" bson:"output"`
Error interface{} `json:"error" bson:"error"` ReplicateOutput interface{} `json:"replicate_output" bson:"replicate_output"`
ExcuteId string `json:"excute_id" bson:"excute_id"` Error interface{} `json:"error" bson:"error"`
AppId string `json:"app_id" bson:"app_id"` ExcuteId string `json:"excute_id" bson:"excute_id"`
CreatedTime time.Time `json:"created_time" bson:"created_time"` AppId string `json:"app_id" bson:"app_id"`
UpdatedTime time.Time `json:"updated_time" bson:"updated_time"` CreatedTime time.Time `json:"created_time" bson:"created_time"`
Deleted int `json:"deleted" bson:"deleted"` UpdatedTime time.Time `json:"updated_time" bson:"updated_time"`
Deleted int `json:"deleted" bson:"deleted"`
} }
type ReplicateRequest struct { type ReplicateRequest struct {
...@@ -70,7 +71,12 @@ type TaskReturn struct { ...@@ -70,7 +71,12 @@ type TaskReturn struct {
type TaskResponse struct { type TaskResponse struct {
Task *TaskReturn `json:"task"` Task *TaskReturn `json:"task"`
Output []string `json:"output"` Output interface{} `json:"output"`
}
type TaskResponseNew struct {
Output interface{} `json:"output,omitempty"`
Error interface{} `json:"error,omitempty"`
} }
type TaskResult struct { type TaskResult struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment