Commit fd96d122 authored by luxq's avatar luxq

add worker running model table

parent 5c2b377e
package db
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type DbWorkerRunningInfo struct {
ID string `bson:"_id,omitempty" json:"id"`
WorkerId string `bson:"worker_id" json:"worker_id"`
ModelId string `bson:"model_id" json:"model_id"`
WaitTime int `bson:"wait_time" json:"wait_time"`
}
type dbWorkerRunning struct {
client *mongo.Client
col *mongo.Collection
}
func NewDBWorkerRunning(client *mongo.Client, database string, collection string) *dbWorkerRunning {
return &dbWorkerRunning{
client: client,
col: client.Database(database).Collection(collection),
}
}
func (d *dbWorkerRunning) InsertWorker(ctx context.Context, worker *DbWorkerRunningInfo) (*mongo.InsertOneResult, error) {
return d.col.InsertOne(ctx, worker)
}
func (d *dbWorkerRunning) UpdateWaitTime(ctx context.Context, id string, waitTime int) error {
update := bson.M{"$set": bson.M{"wait_time": waitTime}}
_, err := d.col.UpdateOne(ctx, bson.M{"_id": id}, update)
return err
}
func (d *dbWorkerRunning) FindWorkerByModelId(ctx context.Context, modelId string, limit int) ([]*DbWorkerRunningInfo, error) {
// find all worker that at least one running model's mode_id is equal modelId
// sort by wait time
findOptions := options.Find()
findOptions.SetLimit(int64(limit))
findOptions.SetSort(bson.D{{"wait_time", 1}})
selector := bson.M{"model_id": modelId}
cursor, err := d.col.Find(ctx, selector, findOptions)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
var workers []*DbWorkerRunningInfo
if err = cursor.All(ctx, &workers); err != nil {
return nil, err
}
return workers, nil
}
......@@ -135,5 +135,10 @@ func (d *dbWorker) FindWorkerByInstallModelAndSortByGpuRam(ctx context.Context,
return nil, err
}
return workers, nil
}
func (d *dbWorker) FindWorkerByWorkerId(ctx context.Context, workerId string) (*DbWorkerInfo, error) {
var worker *DbWorkerInfo
err := d.col.FindOne(ctx, bson.M{"worker_id": workerId}).Decode(&worker)
return worker, err
}
......@@ -19,8 +19,9 @@ import (
var (
idlist = make([]string, 0, 1000000)
//workers = make([]*DbWorkerInfo, 0, 1000000)
database = "test"
collection = "workers"
database = "test"
collection = "workers"
workerRunningCollection = "worker_running"
)
func init() {
......@@ -34,6 +35,7 @@ func init() {
func initdata(client *mongo.Client) []string {
t1 := time.Now()
db := NewDBWorker(client, database, collection)
dbRunning := NewDBWorkerRunning(client, database, workerRunningCollection)
// Insert 1,000,000 DbWorkerInfo to db
for i := 0; i < 1000; i++ {
......@@ -42,6 +44,20 @@ func initdata(client *mongo.Client) []string {
if err != nil {
panic(fmt.Sprintf("insert worker failed with err:%s", err))
}
{
// add worker running info to dbRunning
for _, model := range worker.Models.RunningModels {
runningInfo := &DbWorkerRunningInfo{
WorkerId: worker.WorkerId,
ModelId: model.ModelID,
WaitTime: model.WaitTime,
}
_, err := dbRunning.InsertWorker(context.Background(), runningInfo)
if err != nil {
panic(fmt.Sprintf("insert worker failed with err:%s", err))
}
}
}
id, ok := result.InsertedID.(primitive.ObjectID)
if !ok {
......@@ -440,3 +456,40 @@ func BenchmarkDbWorker_FindWorkerByRunningModelAndSortByWaitTime_Parallel(b *tes
}
})
}
func BenchmarkDbWorkerRunning_FindWorkerByModelId(b *testing.B) {
client, err := ConnectMongoDB("mongodb://localhost:27017")
if err != nil {
log.Fatal(err)
}
db := NewDBWorkerRunning(client, database, collection)
defer db.client.Disconnect(context.Background())
b.ResetTimer()
for i := 0; i < b.N; i++ {
modelId := getRandId(100)
if w, err := db.FindWorkerByModelId(context.Background(), modelId, 10); err != nil {
panic(fmt.Sprintf("find worker failed with err:%s", err))
} else if len(w) == 0 {
b.Logf("FindWorkerByModelId find %d with id %s\n", len(w), modelId)
}
}
}
func BenchmarkDbWorkerRunning_FindWorkerByModelId_Parallel(b *testing.B) {
client, err := ConnectMongoDB("mongodb://localhost:27017")
if err != nil {
log.Fatal(err)
}
db := NewDBWorkerRunning(client, database, workerRunningCollection)
defer db.client.Disconnect(context.Background())
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
modelId := getRandId(100)
if w, err := db.FindWorkerByModelId(context.Background(), modelId, 10); err != nil {
panic(fmt.Sprintf("find worker failed with err:%s", err))
} else if len(w) == 0 {
b.Logf("FindWorkerByModelId find %d with id %s\n", len(w), modelId)
}
}
})
}
......@@ -79,8 +79,3 @@ type WorkerInfo struct {
ModelInofs *ModelInfo `bson:"model_infos" json:"model_infos"`
Hardware *HardwareInfo `bson:"hardware" json:"hardware"`
}
type WorkerModelInfo struct {
WorkerId string `bson:"worker_id" json:"worker_id"`
ModelId string `bson:"model_id" json:"model_id"`
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment