Commit 858dc7dc authored by brent's avatar brent

add replicate

parent 4e0434ce
File added
...@@ -5,19 +5,36 @@ copyrequestbody = true ...@@ -5,19 +5,36 @@ copyrequestbody = true
[dev] [dev]
; aonet
; replicate
whoisApi = "replicate"
apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG" apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG"
taskUrl = "https://api.aonet.ai/api/v1" taskUrl = "https://api.aonet.ai/api/v1"
imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence" imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence"
replicateToken = "r8_9OCCea50go2Qkh0f0jhu3DbNjyzuyt61VNVI6"
replicateTimeout = 10
[test] [test]
whoisApi = "aonet"
apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG" apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG"
taskUrl = "https://api.aonet.ai/api/v1" taskUrl = "https://api.aonet.ai/api/v1"
imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence" imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence"
replicateToken = "r8_9OCCea50go2Qkh0f0jhu3DbNjyzuyt61VNVI6"
replicateTimeout = 10
[prod] [prod]
whoisApi = "aonet"
apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG" apikey = "Rbhpcp0FKNrYNA1nZkrwrIbD0YSSRlVG"
taskUrl = "https://api.aonet.ai/api/v1" taskUrl = "https://api.aonet.ai/api/v1"
imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence" imageUrl = "https://tmp-file.aigic.ai/api/v1/upload/persistence"
replicateToken = "r8_9OCCea50go2Qkh0f0jhu3DbNjyzuyt61VNVI6"
replicateTimeout = 10
include "mysql.conf" include "mysql.conf"
include "mongo.conf" include "mongo.conf"
\ No newline at end of file
; include "replicate_models_version.conf"
\ No newline at end of file
pulid:
version: "43d309c37ab4e62361e5e29b8e9e867fb2dcbcec77ae91206a8d95ac5dd451a0"
url: "https://api.replicate.com/v1/predictions"
stream: false
idm-vton:
version: "906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f"
url: "https://api.replicate.com/v1/predictions"
stream: false
face-swap:
version: "bc479d7d8ecc50eb83839af0c28210db75cac9c23837e2722028df4cddfafa22"
url: "https://api.replicate.com/v1/predictions"
stream: false
lllama3:0.0.8:
version: ""
url: "https://api.replicate.com/v1/models/meta/meta-llama-3-8b/predictions"
stream: false
meta-llama-3-8b-instruct:
version: ""
url: "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions"
stream: false
stable-diffusion-3:
version: ""
url: "https://api.replicate.com/v1/models/stability-ai/stable-diffusion-3/predictions"
stream: false
xtts-v2:
version: "684bc3855b37866c0c65add2ff39c78f3dea3f4ff103a436465326e0f438d55e"
url: "https://api.replicate.com/v1/predictions"
stream: false
sadtalker:
version: "a519cc0cfebaaeade068b23899165a11ec76aaa1d2b313d40d214f204ec957a3"
url: "https://api.replicate.com/v1/predictions"
stream: false
stable-diffusion:
version: "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"
url: "https://api.replicate.com/v1/predictions"
stream: false
\ No newline at end of file
package controllers
import (
"aon_app_server/models"
"aon_app_server/utils/mongo"
"encoding/json"
"github.com/beego/beego/v2/core/logs"
"net/http"
)
type LogController struct {
MainController
}
// Report @router /report/ [post]
func (server *LogController) Report() {
body := server.Ctx.Input.RequestBody
log := models.Log{}
err := json.Unmarshal(body, &log) //解析body中数据
logs.Debug("appRequest", log, string(body))
if err != nil {
server.respond(models.NoRequestBody, err.Error())
return
}
header := server.Ctx.Request.Header
platform := header.Get("Sec-Ch-Ua-Platform")
if log.Platform == "" && platform != "" {
log.Platform = platform
}
userAgent := header.Get("User-Agent")
if log.UserAgent == "" && userAgent != "" {
log.UserAgent = userAgent
}
isMobile := header.Get("Sec-Ch-Ua-Mobile")
if log.IsMobile == "" && isMobile != "" {
log.IsMobile = isMobile
}
secUa := header.Get("Sec-Ch-Ua")
if log.SecUa == "" && secUa != "" {
log.SecUa = secUa
}
_, err = mongo.Insert(&log)
if err != nil {
server.respond(models.BusinessFailed, err.Error())
return
}
server.respond(http.StatusOK, "")
}
// BatchReport @router /batch_report/ [post]
func (server *LogController) BatchReport() {
body := server.Ctx.Input.RequestBody
var data []models.Log
err := json.Unmarshal(body, &data) //解析body中数据
logs.Debug("appRequest", data)
if err != nil {
server.respond(models.NoRequestBody, err.Error())
return
}
header := server.Ctx.Request.Header
for _, log := range data {
platform := header.Get("Sec-Ch-Ua-Platform")
if log.Platform == "" && platform != "" {
log.Platform = platform
}
userAgent := header.Get("User-Agent")
if log.UserAgent == "" && userAgent != "" {
log.UserAgent = userAgent
}
isMobile := header.Get("Sec-Ch-Ua-Mobile")
if log.IsMobile == "" && isMobile != "" {
log.IsMobile = isMobile
}
secUa := header.Get("Sec-Ch-Ua")
if log.SecUa == "" && secUa != "" {
log.SecUa = secUa
}
}
var interfaceSlice []interface{} = make([]interface{}, len(data))
for i, v := range data {
interfaceSlice[i] = v
}
_, err = mongo.InsertMany(interfaceSlice)
if err != nil {
server.respond(models.BusinessFailed, err.Error())
return
}
server.respond(http.StatusOK, "")
}
...@@ -5,11 +5,16 @@ import ( ...@@ -5,11 +5,16 @@ import (
"aon_app_server/utils/mongo" "aon_app_server/utils/mongo"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/logs"
beego "github.com/beego/beego/v2/server/web" beego "github.com/beego/beego/v2/server/web"
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"gopkg.in/yaml.v2"
"io" "io"
"net/http" "net/http"
"os"
"reflect"
"strings"
"time" "time"
) )
...@@ -17,6 +22,15 @@ type TaskController struct { ...@@ -17,6 +22,15 @@ type TaskController struct {
MainController MainController
} }
var supportModels map[string]*models.Model
func init() {
err := readYAML("./conf/replicate_models_version.yaml", &supportModels)
if err != nil {
logs.Debug("Error reading YAML file: %v", err)
}
}
var execTasks = make(chan *models.Task, 100) var execTasks = make(chan *models.Task, 100)
// Prediction @router /prediction/ [post] // Prediction @router /prediction/ [post]
...@@ -37,8 +51,13 @@ func (server *TaskController) Prediction() { ...@@ -37,8 +51,13 @@ func (server *TaskController) Prediction() {
return return
} }
task.Id = id task.Id = id
//execTasks <- &task whois, _ := beego.AppConfig.String("whoisApi")
result, err := sendTask(&task) var result *models.TaskResponse
if whois == "aonet" {
result, err = sendTask(&task)
} else if whois == "replicate" {
result, err = sendReplicate(&task)
}
if err != nil { if err != nil {
server.respond(http.StatusOK, err.Error()) server.respond(http.StatusOK, err.Error())
return return
...@@ -139,17 +158,6 @@ func sendTask(task *models.Task) (*models.TaskResponse, error) { ...@@ -139,17 +158,6 @@ func sendTask(task *models.Task) (*models.TaskResponse, error) {
task.Output = copyImages(response.Output) task.Output = copyImages(response.Output)
mongo.Update(task) mongo.Update(task)
response.Output = task.Output response.Output = task.Output
//data, err := json.Marshal(response.Output)
//if err != nil {
// logs.Info("sendTask Output Unmarshal response:", err)
// //return nil, err
// task.Error = "Output json Unmarshal error"
// mongo.Update(task)
//} else {
// task.Output = string(data)
// mongo.Update(task)
//}
} else { } else {
task.Status = 3 task.Status = 3
task.Error = response.Task task.Error = response.Task
...@@ -157,19 +165,263 @@ func sendTask(task *models.Task) (*models.TaskResponse, error) { ...@@ -157,19 +165,263 @@ func sendTask(task *models.Task) (*models.TaskResponse, error) {
if err != nil { if err != nil {
logs.Info("Update Task Error:", err) logs.Info("Update Task Error:", err)
} }
//data, err := json.Marshal(response.Task)
//if err != nil {
// logs.Info("sendTask response.Task Unmarshal response:", err)
// //return
// task.Error = "response task json Unmarshal error"
//} else {
// task.Error = string(data)
//}
//mongo.Update(task)
} }
return &response, nil return &response, nil
} }
func readYAML(filename string, out interface{}) error {
data, err := os.ReadFile(filename)
if err != nil {
return err
}
err = yaml.Unmarshal(data, out)
if err != nil {
return err
}
return nil
}
func sendReplicate(task *models.Task) (*models.TaskResponse, error) {
host, _ := beego.AppConfig.String("replicateUrl")
url := host
payload := new(bytes.Buffer)
v := reflect.ValueOf(task.Input)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
var input interface{}
for _, key := range v.MapKeys() {
value := v.MapIndex(key).Interface()
keyStr := fmt.Sprintf("%v", key.Interface())
logs.Debug("keyStr: value\n", keyStr, value)
if keyStr == "input" {
input = value
}
}
//versions := map[string]string{
// "pulid": "43d309c37ab4e62361e5e29b8e9e867fb2dcbcec77ae91206a8d95ac5dd451a0",
// "idm-vton": "906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
// "face-swap": "bc479d7d8ecc50eb83839af0c28210db75cac9c23837e2722028df4cddfafa22",
//}
parts := strings.Split(task.ApiPath, "/")
lastElement := ""
if len(parts) > 2 {
lastElement = parts[len(parts)-1]
}
//version := versions[lastElement]
//version, _ := beego.AppConfig.String(lastElement)
//model, _ := config.GetSection(lastElement)
model := supportModels[lastElement]
taskReturn := &models.TaskReturn{}
taskResponse := &models.TaskResponse{
Task: taskReturn,
}
if model == nil {
task.Status = 3
taskReturn.ExecError = "It`s not open yet."
task.Error = taskResponse.Task
_, err := mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return taskResponse, nil
}
data := models.ReplicateRequest{
Version: model.Version,
Input: input,
Stream: model.Stream,
}
url = model.Url
json.NewEncoder(payload).Encode(data)
client := &http.Client{}
request, err := http.NewRequest("POST", url, payload)
if err != nil {
setError(task, "sendReplicate request create error:"+err.Error())
logs.Info("sendReplicate Error NewRequest request:", err)
return nil, err
}
apikey, _ := beego.AppConfig.String("replicateToken")
request.Header.Add("Authorization", "Bearer "+apikey)
logs.Info("sendReplicate client sending request:")
resp, err := client.Do(request)
if err != nil {
setError(task, "Task sending error:"+err.Error())
logs.Info("sendReplicate Error sending request:", err)
return nil, err
}
defer resp.Body.Close()
logs.Info("sendReplicate resp code", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
logs.Info("sendReplicate body", string(body))
if resp.StatusCode == 200 && body == nil {
setError(task, "Task response body null")
logs.Info("sendReplicate Body reading response:", err)
return nil, err
}
if err != nil {
setError(task, "Task read response body error:"+err.Error())
logs.Info("sendReplicate Error reading response:", err)
return nil, err
}
var response models.ReplicateResponse
if err = json.Unmarshal(body, &response); err != nil {
setError(task, "Task response Unmarshal error:"+err.Error())
logs.Info("sendReplicate Error Unmarshal response:", err)
return nil, err
}
if response.Urls.Get != "" || (model.Stream && response.Urls.Stream != "") {
replicateTimeout, _ := beego.AppConfig.Int("replicateTimeout")
timeout := time.After(time.Duration(replicateTimeout) * time.Minute)
for {
select {
case <-timeout:
logs.Info("Operation timed out")
task.Status = 4 // 4表示超时状态
taskReturn.TaskError = "Operation timed out"
taskResponse.Task = taskReturn
task.Error = taskResponse.Task
_, err := mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return taskResponse, nil
default:
temp, err := getReplicate(response.Urls.Get)
if err != nil {
logs.Info("getReplicate Task Error:", err)
}
logs.Info("getReplicate Task temp:", temp)
if temp != nil && temp.Status == "succeeded" {
//todo 返回
var output []string
if str, ok := temp.Output.(string); ok {
fmt.Println("i 是字符串类型,值为:", str)
output = append(output, str)
}
if slice, ok := temp.Output.([]string); ok {
fmt.Println("i 是字符串数组类型,值为:", slice)
output = slice
}
if slice, ok := temp.Output.([]interface{}); ok {
fmt.Println("i 是interface{}数组类型,值为:", slice)
for _, value := range slice {
if str, ok := value.(string); ok {
fmt.Println("i 是字符串类型,值为:", str)
output = append(output, str)
}
}
}
if slice, ok := temp.Output.(map[string]string); ok {
fmt.Println("i 是map类型,值为:", slice)
for _, value := range slice {
output = append(output, value)
}
}
task.Status = 2
task.Output = output
mongo.Update(task)
taskResponse.Output = task.Output
taskResponse.Task.IsSuccess = true
taskResponse.Task.ExecCode = 200
return taskResponse, nil
} else if temp != nil && temp.Error != nil {
task.Status = 3
if response.Error != nil {
data, err := json.Marshal(response.Error)
if err != nil {
logs.Info("sendTask response.Task Unmarshal response:", err)
}
if data != nil && len(data) > 0 {
taskReturn.ExecError = string(data)
}
taskResponse.Task = taskReturn
task.Error = taskResponse.Task
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
return taskResponse, nil
}
time.Sleep(time.Second)
}
}
} else {
task.Status = 3
if response.Error != nil {
data, err := json.Marshal(response.Error)
if err != nil {
logs.Info("sendTask response.Task Unmarshal response:", err)
}
if data != nil && len(data) > 0 {
taskReturn.ExecError = string(data)
}
taskResponse.Task = taskReturn
task.Error = taskResponse.Task
}
_, err = mongo.Update(task)
if err != nil {
logs.Info("Update Task Error:", err)
}
}
return taskResponse, nil
}
func getReplicate(url string) (*models.ReplicateResponse, error) {
//host, _ := beego.AppConfig.String("replicateUrl")
//url := host + task.ApiPath
//payload := new(bytes.Buffer)
//var input interface{}
//if err := json.Unmarshal([]byte(task.Input), &input); err != nil {
// setError(task, "task.Input Unmarshal error:"+err.Error())
// logs.Info("sendTask task.Input Unmarshal response:", err)
// return nil, err
//}
//json.NewEncoder(payload).Encode(task.Input)
client := &http.Client{}
request, err := http.NewRequest("GET", url, nil)
if err != nil {
logs.Info("getReplicate Error NewRequest request:", err)
return nil, err
}
apikey, _ := beego.AppConfig.String("replicateToken")
request.Header.Add("Authorization", "Bearer "+apikey)
logs.Info("getReplicate client sending request:")
resp, err := client.Do(request)
if err != nil {
logs.Info("getReplicate Error sending request:", err)
return nil, err
}
defer resp.Body.Close()
logs.Info("getReplicate resp code", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
logs.Info("getReplicate body", string(body))
if resp.StatusCode == 200 && body == nil {
logs.Info("getReplicate Body reading response:", err)
return nil, err
}
if err != nil {
logs.Info("getReplicate Error reading response:", err)
return nil, err
}
var response models.ReplicateResponse
if err = json.Unmarshal(body, &response); err != nil {
logs.Info("getReplicate Error Unmarshal response:", err)
return nil, err
}
return &response, nil
}
func setError(task *models.Task, error string) { func setError(task *models.Task, error string) {
task.Status = 3 task.Status = 3
task.Error = error task.Error = error
...@@ -213,6 +465,7 @@ func (server *TaskController) List() { ...@@ -213,6 +465,7 @@ func (server *TaskController) List() {
total, data, err := mongo.Query("Task", request.Page, request.Size, request.Filter) total, data, err := mongo.Query("Task", request.Page, request.Size, request.Filter)
if err != nil { if err != nil {
logs.Info("List Error:", err) logs.Info("List Error:", err)
server.respond(models.BusinessFailed, err.Error())
} }
var tasks []models.Task var tasks []models.Task
for _, bsonD := range data { for _, bsonD := range data {
......
...@@ -43,5 +43,5 @@ require ( ...@@ -43,5 +43,5 @@ require (
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.23.0 // indirect google.golang.org/protobuf v1.23.0 // indirect
gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
) )
...@@ -324,6 +324,8 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= ...@@ -324,6 +324,8 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
......
{"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1719017412063939595} {"/Users/brent/Documents/wubanWork/aon_app_server/controllers":1719995732323745955}
\ No newline at end of file \ No newline at end of file
package models
import "time"
type Log struct {
Id interface{} `json:"id" bson:"_id,omitempty"`
Version string `json:"sdk_version,omitempty" bson:"sdk_version"`
Time time.Time `json:"time" bson:"time"`
UserId string `json:"user_id" bson:"user_id"`
Device string `json:"device" bson:"device"`
Platform string `json:"platform" bson:"platform"`
Error string `json:"error" bson:"error"`
Method string `json:"method" bson:"method"`
Input interface{} `json:"input" bson:"input"`
UserAgent string `json:"user_agent" bson:"user_agent"`
IsMobile string `json:"is_mobile" bson:"is_mobile"`
SecUa string `json:"sec_ua" bson:"sec_ua"`
}
...@@ -28,13 +28,37 @@ type Task struct { ...@@ -28,13 +28,37 @@ type Task struct {
Deleted int `json:"deleted" bson:"deleted"` Deleted int `json:"deleted" bson:"deleted"`
} }
type ReplicateRequest struct {
Version string `json:"version,omitempty"`
Input interface{} `json:"input"`
Stream bool `json:"stream,omitempty"`
}
type ReplicateResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Version string `json:"version"`
Input interface{} `json:"input"`
Logs string `json:"logs"`
Output interface{} `json:"output"`
DataRemoved bool `json:"data_removed"`
Error interface{} `json:"error"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
Urls struct {
Cancel string `json:"cancel"`
Get string `json:"get"`
Stream string `json:"stream"`
} `json:"urls"`
}
type TaskReturn struct { type TaskReturn struct {
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
TaskUid string `json:"task_uid"` TaskUid string `json:"task_uid"`
TaskFee string `json:"task_fee"` TaskFee string `json:"task_fee"`
IsSuccess bool `json:"is_success"` IsSuccess bool `json:"is_success"`
TaskError string `json:"task_error"` TaskError string `json:"task_error"`
ExecCode string `json:"exec_code"` ExecCode int `json:"exec_code"`
ExecError string `json:"exec_error"` ExecError string `json:"exec_error"`
ApiError struct { ApiError struct {
RequestId string `json:"request_id"` RequestId string `json:"request_id"`
...@@ -43,8 +67,8 @@ type TaskReturn struct { ...@@ -43,8 +67,8 @@ type TaskReturn struct {
} }
type TaskResponse struct { type TaskResponse struct {
Task TaskReturn `json:"task"` Task *TaskReturn `json:"task"`
Output []string `json:"output"` Output []string `json:"output"`
} }
type TaskResult struct { type TaskResult struct {
......
package models
var Versions = map[string]string{
"pulid": "43d309c37ab4e62361e5e29b8e9e867fb2dcbcec77ae91206a8d95ac5dd451a0",
"idm-vton": "906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
"face-swap": "bc479d7d8ecc50eb83839af0c28210db75cac9c23837e2722028df4cddfafa22",
}
type Model struct {
Version string `yaml:"version"`
Url string `yaml:"url"`
Stream bool `yaml:"stream"`
}
...@@ -8,6 +8,8 @@ import ( ...@@ -8,6 +8,8 @@ import (
func init() { func init() {
beego.Router("/", &controllers.MainController{}) beego.Router("/", &controllers.MainController{})
beego.AutoPrefix("app/api/v1", &controllers.TaskController{}) beego.AutoPrefix("app/api/v1", &controllers.TaskController{})
beego.AutoPrefix("app/api/v1", &controllers.LogController{})
//ns := beego.NewNamespace("app", //ns := beego.NewNamespace("app",
// beego.NSNamespace("api", // beego.NSNamespace("api",
// beego.NSNamespace("v1", // beego.NSNamespace("v1",
......
...@@ -91,6 +91,29 @@ func Insert(i interface{}) (interface{}, error) { ...@@ -91,6 +91,29 @@ func Insert(i interface{}) (interface{}, error) {
return res.InsertedID, err return res.InsertedID, err
} }
func InsertMany(i []interface{}) ([]interface{}, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
collectionName := ""
for _, value := range i {
t := reflect.TypeOf(value)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
logs.Debug("结构体名称 = ", t.Name())
collectionName = t.Name()
}
collection := database.Collection(collectionName)
res, err := collection.InsertMany(ctx, i)
if err != nil {
return nil, err
}
return res.InsertedIDs, err
}
func Update(i interface{}) (interface{}, error) { func Update(i interface{}) (interface{}, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -206,9 +229,31 @@ func Query(collectionName string, page int, size int, filter interface{}) (int64 ...@@ -206,9 +229,31 @@ func Query(collectionName string, page int, size int, filter interface{}) (int64
fieldName := v.Type().Field(i).Name fieldName := v.Type().Field(i).Name
fieldValue := v.Field(i).Interface() fieldValue := v.Field(i).Interface()
if fieldName == "created_time" { if fieldName == "created_time" {
t := reflect.TypeOf(fieldValue)
logs.Debug("created_time type = ", t)
interfaceSlice, ok := fieldValue.([]interface{})
if !ok {
return 0, nil, errors.New("created_time must be a Array")
}
if len(interfaceSlice) < 2 {
return 0, nil, errors.New("created_time len must = 2")
}
var createdTime []time.Time
for _, v := range interfaceSlice {
if str, ok := v.(string); ok {
startTime, err := time.Parse(time.DateTime, str)
if err != nil {
logs.Debug(err)
}
createdTime = append(createdTime, startTime)
} else {
fmt.Printf("类型断言失败,遇到非 string 类型的值: %v\n", v)
}
}
finalFilter[fieldName] = bson.M{ finalFilter[fieldName] = bson.M{
"$gte": "", "$gte": createdTime[0],
"$lt": "", "$lt": createdTime[1],
} }
continue continue
} }
...@@ -223,9 +268,31 @@ func Query(collectionName string, page int, size int, filter interface{}) (int64 ...@@ -223,9 +268,31 @@ func Query(collectionName string, page int, size int, filter interface{}) (int64
keyStr := fmt.Sprintf("%v", key.Interface()) keyStr := fmt.Sprintf("%v", key.Interface())
logs.Debug("keyStr: value\n", keyStr, value) logs.Debug("keyStr: value\n", keyStr, value)
if keyStr == "created_time" { if keyStr == "created_time" {
t := reflect.TypeOf(value)
logs.Debug("created_time type = ", t)
interfaceSlice, ok := value.([]interface{})
if !ok {
return 0, nil, errors.New("created_time must be a Array")
}
if len(interfaceSlice) < 2 {
return 0, nil, errors.New("created_time len must = 2")
}
var createdTime []time.Time
for _, v := range interfaceSlice {
if str, ok := v.(string); ok {
startTime, err := time.Parse(time.DateTime, str)
if err != nil {
logs.Debug(err)
}
createdTime = append(createdTime, startTime)
} else {
fmt.Printf("类型断言失败,遇到非 string 类型的值: %v\n", v)
}
}
finalFilter[keyStr] = bson.M{ finalFilter[keyStr] = bson.M{
"$gte": "", "$gte": createdTime[0],
"$lt": "", "$lt": createdTime[1],
} }
continue continue
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment