Commit afaa731b authored by Wade's avatar Wade

add QueryRewriteWithSummary

parent 27b3a0bb
// // Copyright 2025
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package knowledge
// import (
// "context"
// "encoding/json"
// "fmt"
// "os"
// "strings"
// "sync"
// "github.com/rs/zerolog/log"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
// lkeap "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/lkeap/v20240522"
// )
// // ClientConfig holds configuration options for the Tencent Cloud LKEAP client.
// type ClientConfig struct {
// SecretID string // Tencent Cloud Secret ID
// SecretKey string // Tencent Cloud Secret Key
// Token string // Optional: Temporary token for authentication
// Endpoint string // API endpoint (default: lkeap.tencentcloudapi.com)
// Region string // Tencent Cloud region (optional)
// }
// // KnowledgeClient manages interactions with the Tencent Cloud LKEAP API.
// type KnowledgeClient struct {
// client *lkeap.Client
// config ClientConfig
// mu sync.Mutex
// initted bool
// }
// // NewKnowledgeClient creates a new KnowledgeClient with the given configuration.
// func NewKnowledgeClient(config ClientConfig) *KnowledgeClient {
// log.Info().
// Str("method", "NewKnowledgeClient").
// Str("endpoint", config.Endpoint).
// Str("region", config.Region).
// Str("secret_id", maskCredential(config.SecretID)).
// Str("token", maskCredential(config.Token)).
// Msg("Creating new KnowledgeClient")
// return &KnowledgeClient{
// config: config,
// }
// }
// // Init initializes the KnowledgeClient.
// func (kc *KnowledgeClient) Init(ctx context.Context) error {
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initializing KnowledgeClient")
// kc.mu.Lock()
// defer kc.mu.Unlock()
// if kc.initted {
// log.Error().Str("method", "KnowledgeClient.Init").Msg("Client already initialized")
// return fmt.Errorf("knowledge client already initialized")
// }
// // Load configuration from environment variables if not set
// if kc.config.SecretID == "" {
// kc.config.SecretID = os.Getenv("TENCENTCLOUD_SECRET_ID")
// }
// if kc.config.SecretKey == "" {
// kc.config.SecretKey = os.Getenv("TENCENTCLOUD_SECRET_KEY")
// }
// if kc.config.Token == "" {
// kc.config.Token = os.Getenv("TENCENTCLOUD_TOKEN")
// }
// if kc.config.Endpoint == "" {
// kc.config.Endpoint = "lkeap.tencentcloudapi.com"
// }
// if kc.config.Region == "" {
// kc.config.Region = "ap-guangzhou"
// }
// // Validate configuration
// if kc.config.SecretID == "" || kc.config.SecretKey == "" {
// log.Error().Str("method", "KnowledgeClient.Init").Msg("SecretID and SecretKey are required")
// return fmt.Errorf("knowledge: SecretID and SecretKey are required")
// }
// // Create credential
// var credential *common.Credential
// if kc.config.Token != "" {
// credential = common.NewTokenCredential(kc.config.SecretID, kc.config.SecretKey, kc.config.Token)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using temporary token credential")
// } else {
// credential = common.NewCredential(kc.config.SecretID, kc.config.SecretKey)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using standard credential")
// }
// // Create client profile
// cpf := profile.NewClientProfile()
// cpf.HttpProfile.Endpoint = kc.config.Endpoint
// // Initialize client
// client, err := lkeap.NewClient(credential, kc.config.Region, cpf)
// if err != nil {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.Init").
// Msg("Failed to create LKEAP client")
// return err
// }
// kc.client = client
// kc.initted = true
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initialization successful")
// return nil
// }
// // QueryRewriteRequest defines the input for a query rewrite operation.
// type QueryRewriteRequest struct {
// Messages []*lkeap.Message // Multi-turn conversation history (up to 4 turns)
// Model string // Model name for query rewriting
// }
// // QueryRewriteResponse defines the output of a query rewrite operation.
// type QueryRewriteResponse struct {
// RewrittenQuery string // The rewritten query
// RawResponse *lkeap.QueryRewriteResponse
// }
// // QueryRewrite performs a query rewrite using the Tencent Cloud LKEAP API.
// func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteRequest) (*QueryRewriteResponse, error) {
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("message_count", len(req.Messages)).
// Str("model", req.Model).
// Msg("Starting query rewrite operation")
// if !kc.initted {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("Client not initialized")
// return nil, fmt.Errorf("knowledge client not initialized; call Init first")
// }
// // Validate input
// if len(req.Messages) == 0 {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("At least one message is required")
// return nil, fmt.Errorf("at least one message is required")
// }
// if len(req.Messages) > 4 {
// log.Warn().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("message_count", len(req.Messages)).
// Msg("Message count exceeds 4, truncating to 4")
// req.Messages = req.Messages[:4]
// }
// for i, msg := range req.Messages {
// if msg.Role == nil || *msg.Role == "" {
// log.Error().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Msg("Role is required in each message")
// return nil, fmt.Errorf("message at index %d missing role", i)
// }
// if *msg.Role != "user" && *msg.Role != "assistant" {
// log.Error().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Str("role", *msg.Role).
// Msg("Invalid role; must be 'user' or 'assistant'")
// return nil, fmt.Errorf("invalid role '%s' at index %d", *msg.Role, i)
// }
// if msg.Content == nil || *msg.Content == "" {
// log.Error().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Msg("Content is required in each message")
// return nil, fmt.Errorf("message at index %d missing content", i)
// }
// log.Debug().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Str("role", *msg.Role).
// Str("content", *msg.Content).
// Msg("Validated message")
// }
// if req.Model == "" {
// log.Warn().Str("method", "KnowledgeClient.QueryRewrite").Msg("Model not specified, using default")
// req.Model = "lke-query-rewrite-base"
// }
// // Create Tencent Cloud request
// tencentReq := lkeap.NewQueryRewriteRequest()
// tencentReq.Messages = req.Messages
// if req.Model != "" {
// tencentReq.Model = common.StringPtr(req.Model)
// }
// // Debug request
// tencentReqAsJson, _ := json.Marshal(tencentReq)
// log.Debug().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("request_json", string(tencentReqAsJson)).
// Msg("Prepared Tencent Cloud request")
// // Perform request
// response, err := kc.client.QueryRewriteWithContext(ctx, tencentReq)
// if err != nil {
// if _, ok := err.(*errors.TencentCloudSDKError); ok {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Tencent Cloud API error")
// return nil, fmt.Errorf("tencent cloud api error: %w", err)
// }
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Failed to perform query rewrite")
// return nil, fmt.Errorf("query rewrite failed: %w", err)
// }
// // Extract response fields
// var rewrittenQuery string
// var requestId string
// if response.Response.Content != nil {
// rewrittenQuery = *response.Response.Content
// }
// if response.Response.RequestId != nil {
// requestId = *response.Response.RequestId
// }
// result := &QueryRewriteResponse{
// RewrittenQuery: rewrittenQuery,
// RawResponse: response,
// }
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("rewritten_query", rewrittenQuery).
// Str("request_id", requestId).
// Interface("usage", response.Response.Usage).
// Str("raw_response", response.ToJsonString()).
// Msg("Query rewrite operation completed successfully")
// return result, nil
// }
// // QueryRewriteWithSummary wraps QueryRewrite to handle a user question, assistant answer, and history summary.
// func (kc *KnowledgeClient) QueryRewriteWithSummary(ctx context.Context, userQuestion, assistantAnswer, historySummary string) (*QueryRewriteResponse, error) {
// log.Info().
// Str("method", "KnowledgeClient.QueryRewriteWithSummary").
// Str("user_question", userQuestion).
// Str("assistant_answer", assistantAnswer).
// Str("history_summary", historySummary).
// Msg("Starting query rewrite with summary operation")
// if userQuestion == "" || assistantAnswer == "" {
// log.Error().Str("method", "KnowledgeClient.QueryRewriteWithSummary").Msg("User question and assistant answer are required")
// return nil, fmt.Errorf("user question and assistant answer are required")
// }
// // Construct messages
// messages := []*lkeap.Message{
// {
// Role: common.StringPtr("user"),
// Content: common.StringPtr(userQuestion),
// },
// {
// Role: common.StringPtr("assistant"),
// Content: common.StringPtr(assistantAnswer),
// },
// }
// // Append history summary as an assistant message if provided
// if historySummary != "" {
// messages = append(messages, &lkeap.Message{
// Role: common.StringPtr("assistant"),
// Content: common.StringPtr(fmt.Sprintf("Conversation summary: %s", historySummary)),
// })
// }
// // Create request
// req := QueryRewriteRequest{
// Messages: messages,
// Model: "lke-query-rewrite-base",
// }
// // Call QueryRewrite
// return kc.QueryRewrite(ctx, req)
// }
// // maskCredential masks sensitive credentials for logging
// func maskCredential(cred string) string {
// if len(cred) <= 8 {
// return strings.Repeat("*", len(cred))
// }
// return cred[:4] + strings.Repeat("*", len(cred)-8) + cred[len(cred)-4:]
// }
// Copyright 2025 // Copyright 2025
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -21,6 +329,7 @@ import ( ...@@ -21,6 +329,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
...@@ -53,6 +362,8 @@ func NewKnowledgeClient(config ClientConfig) *KnowledgeClient { ...@@ -53,6 +362,8 @@ func NewKnowledgeClient(config ClientConfig) *KnowledgeClient {
Str("method", "NewKnowledgeClient"). Str("method", "NewKnowledgeClient").
Str("endpoint", config.Endpoint). Str("endpoint", config.Endpoint).
Str("region", config.Region). Str("region", config.Region).
Str("secret_id", maskCredential(config.SecretID)).
Str("token", maskCredential(config.Token)).
Msg("Creating new KnowledgeClient") Msg("Creating new KnowledgeClient")
return &KnowledgeClient{ return &KnowledgeClient{
config: config, config: config,
...@@ -70,7 +381,7 @@ func (kc *KnowledgeClient) Init(ctx context.Context) error { ...@@ -70,7 +381,7 @@ func (kc *KnowledgeClient) Init(ctx context.Context) error {
return fmt.Errorf("knowledge client already initialized") return fmt.Errorf("knowledge client already initialized")
} }
// Load configuration defaults // Load configuration from environment variables if not set
if kc.config.SecretID == "" { if kc.config.SecretID == "" {
kc.config.SecretID = os.Getenv("TENCENTCLOUD_SECRET_ID") kc.config.SecretID = os.Getenv("TENCENTCLOUD_SECRET_ID")
} }
...@@ -84,7 +395,7 @@ func (kc *KnowledgeClient) Init(ctx context.Context) error { ...@@ -84,7 +395,7 @@ func (kc *KnowledgeClient) Init(ctx context.Context) error {
kc.config.Endpoint = "lkeap.tencentcloudapi.com" kc.config.Endpoint = "lkeap.tencentcloudapi.com"
} }
if kc.config.Region == "" { if kc.config.Region == "" {
kc.config.Region = "" // Region can be empty for global APIs kc.config.Region = "ap-guangzhou" // Default to ap-guangzhou as per curl
} }
// Validate configuration // Validate configuration
...@@ -192,23 +503,22 @@ func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteReq ...@@ -192,23 +503,22 @@ func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteReq
} }
if req.Model == "" { if req.Model == "" {
log.Warn().Str("method", "KnowledgeClient.QueryRewrite").Msg("Model not specified, using default") log.Warn().Str("method", "KnowledgeClient.QueryRewrite").Msg("Model not specified, using default")
req.Model = "lke-query-rewrite-base" // Default as per curl
} }
// Create Tencent Cloud request // Create Tencent Cloud request
tencentReq := lkeap.NewQueryRewriteRequest() tencentReq := lkeap.NewQueryRewriteRequest()
tencentReq.Messages = req.Messages // Directly use validated messages tencentReq.Messages = req.Messages
if req.Model != "" { if req.Model != "" {
tencentReq.Model = common.StringPtr(req.Model) tencentReq.Model = common.StringPtr(req.Model)
} }
// Debug request
tencentReqAsJson, _ := json.Marshal(tencentReq)
fmt.Println("len(tencentReq.Messages)",len(tencentReq.Messages)) log.Debug().
Str("method", "KnowledgeClient.QueryRewrite").
Str("request_json", string(tencentReqAsJson)).
tencentReqAsJson, _ := json.Marshal(tencentReq) Msg("Prepared Tencent Cloud request")
fmt.Println("len(tencentReq.Messages) json marsh",string(tencentReqAsJson))
// Perform request // Perform request
response, err := kc.client.QueryRewriteWithContext(ctx, tencentReq) response, err := kc.client.QueryRewriteWithContext(ctx, tencentReq)
...@@ -226,526 +536,83 @@ func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteReq ...@@ -226,526 +536,83 @@ func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteReq
Msg("Failed to perform query rewrite") Msg("Failed to perform query rewrite")
return nil, fmt.Errorf("query rewrite failed: %w", err) return nil, fmt.Errorf("query rewrite failed: %w", err)
} }
// Extract response fields
var rewrittenQuery string
var requestId string
if response.Response.Content != nil {
rewrittenQuery = *response.Response.Content
}
if response.Response.RequestId != nil {
requestId = *response.Response.RequestId
}
result := &QueryRewriteResponse{ // Extract response fields
RewrittenQuery: rewrittenQuery, var rewrittenQuery string
// Usage: response.Response.Usage, var requestId string
// RequestId: requestId, if response.Response.Content != nil {
RawResponse: response, rewrittenQuery = *response.Response.Content
}
if response.Response.RequestId != nil {
requestId = *response.Response.RequestId
}
result := &QueryRewriteResponse{
RewrittenQuery: rewrittenQuery,
RawResponse: response,
}
log.Info().
Str("method", "KnowledgeClient.QueryRewrite").
Str("rewritten_query", rewrittenQuery).
Str("request_id", requestId).
Interface("usage", response.Response.Usage).
Str("raw_response", response.ToJsonString()).
Msg("Query rewrite operation completed successfully")
return result, nil
} }
log.Info(). // maskCredential masks sensitive credentials for logging
Str("method", "KnowledgeClient.QueryRewrite"). func maskCredential(cred string) string {
Str("rewritten_query", rewrittenQuery). if len(cred) <= 8 {
Str("request_id", requestId). return strings.Repeat("*", len(cred))
Interface("usage", response.Response.Usage). }
Str("raw_response", response.ToJsonString()). return cred[:4] + strings.Repeat("*", len(cred)-8) + cred[len(cred)-4:]
Msg("Query rewrite operation completed successfully")
return result, nil
} }
// // Extract rewritten query
// var rewrittenQuery string
// if response.Response.RewrittenQuery != nil {
// rewrittenQuery = *response.Response.RewrittenQuery
// }
// result := &QueryRewriteResponse{
// RewrittenQuery: rewrittenQuery,
// RawResponse: response,
// }
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("rewritten_query", rewrittenQuery).
// Str("raw_response", response.ToJsonString()).
// Msg("Query rewrite operation completed successfully")
// return result, nil
// }
// // Copyright 2025
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package knowledge
// import (
// "context"
// "fmt"
// "os"
// "sync"
// "github.com/rs/zerolog/log"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
// lkeap "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/lkeap/v20240522"
// )
// // ClientConfig holds configuration options for the Tencent Cloud LKEAP client.
// type ClientConfig struct {
// SecretID string // Tencent Cloud Secret ID
// SecretKey string // Tencent Cloud Secret Key
// Token string // Optional: Temporary token for authentication
// Endpoint string // API endpoint (default: lkeap.tencentcloudapi.com)
// Region string // Tencent Cloud region (optional)
// }
// // KnowledgeClient manages interactions with the Tencent Cloud LKEAP API.
// type KnowledgeClient struct {
// client *lkeap.Client
// config ClientConfig
// mu sync.Mutex
// initted bool
// }
// // NewKnowledgeClient creates a new KnowledgeClient with the given configuration.
// func NewKnowledgeClient(config ClientConfig) *KnowledgeClient {
// log.Info().
// Str("method", "NewKnowledgeClient").
// Str("endpoint", config.Endpoint).
// Str("region", config.Region).
// Msg("Creating new KnowledgeClient")
// return &KnowledgeClient{
// config: config,
// }
// }
// // Init initializes the KnowledgeClient.
// func (kc *KnowledgeClient) Init(ctx context.Context) error {
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initializing KnowledgeClient")
// kc.mu.Lock()
// defer kc.mu.Unlock()
// if kc.initted {
// log.Error().Str("method", "KnowledgeClient.Init").Msg("Client already initialized")
// return errors.New("knowledge client already initialized")
// }
// // Load configuration defaults
// if kc.config.SecretID == "" {
// kc.config.SecretID = os.Getenv("TENCENTCLOUD_SECRET_ID")
// }
// if kc.config.SecretKey == "" {
// kc.config.SecretKey = os.Getenv("TENCENTCLOUD_SECRET_KEY")
// }
// if kc.config.Token == "" {
// kc.config.Token = os.Getenv("TENCENTCLOUD_TOKEN")
// }
// if kc.config.Endpoint == "" {
// kc.config.Endpoint = "lkeap.tencentcloudapi.com"
// }
// if kc.config.Region == "" {
// kc.config.Region = "" // Region can be empty for global APIs
// }
// // Validate configuration
// if kc.config.SecretID == "" || kc.config.SecretKey == "" {
// log.Error().Str("method", "KnowledgeClient.Init").Msg("SecretID and SecretKey are required")
// return fmt.Errorf("knowledge: SecretID and SecretKey are required")
// }
// // Create credential
// var credential *common.Credential
// if kc.config.Token != "" {
// credential = common.NewTokenCredential(kc.config.SecretID, kc.config.SecretKey, kc.config.Token)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using temporary token credential")
// } else {
// credential = common.NewCredential(kc.config.SecretID, kc.config.SecretKey)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using standard credential")
// }
// // Create client profile
// cpf := profile.NewClientProfile()
// cpf.HttpProfile.Endpoint = kc.config.Endpoint
// // Initialize client
// client, err := lkeap.NewClient(credential, kc.config.Region, cpf)
// if err != nil {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.Init").
// Msg("Failed to create LKEAP client")
// return err
// }
// kc.client = client
// kc.initted = true
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initialization successful")
// return nil
// }
// // Message represents a single turn in a conversation.
// type Message struct {
// User string // User's question
// Assistant string // Assistant's response
// }
// // QueryRewriteRequest defines the input for a query rewrite operation.
// type QueryRewriteRequest struct {
// Messages []*Message // Multi-turn conversation history (up to 4 turns)
// Model string // Model name for query rewriting
// }
// // QueryRewriteResponse defines the output of a query rewrite operation.
// type QueryRewriteResponse struct {
// RewrittenQuery string // The rewritten query
// RawResponse *lkeap.QueryRewriteResponse
// }
// // QueryRewrite performs a query rewrite using the Tencent Cloud LKEAP API.
// func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteRequest) (*QueryRewriteResponse, error) {
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("message_count", len(req.Messages)).
// Str("model", req.Model).
// Msg("Starting query rewrite operation")
// if !kc.initted {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("Client not initialized")
// return nil, fmt.Errorf("knowledge client not initialized; call Init first")
// }
// // Validate input
// if len(req.Messages) == 0 {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("At least one message is required")
// return nil, fmt.Errorf("at least one message is required")
// }
// if len(req.Messages) > 4 {
// log.Warn().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("message_count", len(req.Messages)).
// Msg("Message count exceeds 4, truncating to 4")
// req.Messages = req.Messages[:4]
// }
// for i, msg := range req.Messages {
// if msg.User == "" || msg.Assistant == "" {
// log.Error().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Msg("User and Assistant fields are required in each message")
// return nil, fmt.Errorf("message at index %d missing user or assistant", i)
// }
// }
// if req.Model == "" {
// log.Warn().Str("method", "KnowledgeClient.QueryRewrite").Msg("Model not specified, using default")
// }
// // Create Tencent Cloud request
// tencentReq := lkeap.NewQueryRewriteRequest()
// tencentReq.Messages = make([]*lkeap.Message, len(req.Messages))
// for i, msg := range req.Messages {
// tencentReq.Messages[i] = &lkeap.Message{
// User: common.StringPtr(msg.User),
// Assistant: common.StringPtr(msg.Assistant),
// }
// log.Debug().
// Str("method", "KnowledgeClient.QueryRewrite").
// Int("index", i).
// Str("user", msg.User).
// Str("assistant", msg.Assistant).
// Msg("Added message to request")
// }
// if req.Model != "" {
// tencentReq.Model = common.StringPtr(req.Model)
// }
// // Perform request
// response, err := kc.client.QueryRewriteWithContext(ctx, tencentReq)
// if err != nil {
// if _, ok := err.(*errors.TencentCloudSDKError); ok {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Tencent Cloud API error")
// return nil, fmt.Errorf("tencent cloud api error: %w", err)
// }
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Failed to perform query rewrite")
// return nil, fmt.Errorf("query rewrite failed: %w", err)
// }
// // Extract rewritten query
// var rewrittenQuery string
// if response.Response.RewrittenQuery != nil {
// rewrittenQuery = *response.Response.RewrittenQuery
// }
// result := &QueryRewriteResponse{
// RewrittenQuery: rewrittenQuery,
// RawResponse: response,
// }
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("rewritten_query", rewrittenQuery).
// Str("raw_response", response.ToJsonString()).
// Msg("Query rewrite operation completed successfully")
// return result, nil
// }
// // Copyright 2025
// //
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// //
// // http://www.apache.org/licenses/LICENSE-2.0
// //
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
// //
// // SPDX-License-Identifier: Apache-2.0
// package knowledge
// import (
// "context"
// "fmt"
// "os"
// "sync"
// "github.com/rs/zerolog/log"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/errors"
// "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
// lkeap "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/lkeap/v20240522"
// )
// // ClientConfig holds configuration options for the Tencent Cloud LKEAP client.
// type ClientConfig struct {
// SecretID string // Tencent Cloud Secret ID
// SecretKey string // Tencent Cloud Secret Key
// Token string // Optional: Temporary token for authentication
// Endpoint string // API endpoint (default: lkeap.tencentcloudapi.com)
// Region string // Tencent Cloud region (optional)
// }
// // KnowledgeClient manages interactions with the Tencent Cloud LKEAP API.
// type KnowledgeClient struct {
// client *lkeap.Client
// config ClientConfig
// mu sync.Mutex
// initted bool
// }
// // NewKnowledgeClient creates a new KnowledgeClient with the given configuration.
// func NewKnowledgeClient(config ClientConfig) *KnowledgeClient {
// log.Info().
// Str("method", "NewKnowledgeClient").
// Str("endpoint", config.Endpoint).
// Str("region", config.Region).
// Msg("Creating new KnowledgeClient")
// return &KnowledgeClient{
// config: config,
// }
// }
// // Init initializes the KnowledgeClient.
// func (kc *KnowledgeClient) Init(ctx context.Context) error {
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initializing KnowledgeClient")
// kc.mu.Lock()
// defer kc.mu.Unlock()
// if kc.initted { // QueryRewriteWithSummary wraps QueryRewrite to handle a user question, assistant answer, and history summary.
// log.Error().Str("method", "KnowledgeClient.Init").Msg("Client already initialized") func (kc *KnowledgeClient) QueryRewriteWithSummary(ctx context.Context, userQuestion, assistantAnswer, historySummary string) (*QueryRewriteResponse, error) {
// return fmt.Errorf("knowledge client already initialized") log.Info().
// } Str("method", "KnowledgeClient.QueryRewriteWithSummary").
Str("user_question", userQuestion).
// // Load configuration defaults Str("assistant_answer", assistantAnswer).
// if kc.config.SecretID == "" { Str("history_summary", historySummary).
// kc.config.SecretID = os.Getenv("TENCENTCLOUD_SECRET_ID") Msg("Starting query rewrite with summary operation")
// }
// if kc.config.SecretKey == "" { if userQuestion == "" || assistantAnswer == "" {
// kc.config.SecretKey = os.Getenv("TENCENTCLOUD_SECRET_KEY") log.Error().Str("method", "KnowledgeClient.QueryRewriteWithSummary").Msg("User question and assistant answer are required")
// } return nil, fmt.Errorf("user question and assistant answer are required")
// if kc.config.Token == "" { }
// kc.config.Token = os.Getenv("TENCENTCLOUD_TOKEN")
// }
// if kc.config.Endpoint == "" {
// kc.config.Endpoint = "lkeap.tencentcloudapi.com"
// }
// if kc.config.Region == "" {
// kc.config.Region = "" // Region can be empty for global APIs
// }
// // Validate configuration
// if kc.config.SecretID == "" || kc.config.SecretKey == "" {
// log.Error().Str("method", "KnowledgeClient.Init").Msg("SecretID and SecretKey are required")
// //return errors.New("knowledge: SecretID and SecretKey are required")
// return fmt.Errorf("knowledge: SecretID and SecretKey are required")
// }
// // Create credential
// var credential *common.Credential
// if kc.config.Token != "" {
// credential = common.NewTokenCredential(kc.config.SecretID, kc.config.SecretKey, kc.config.Token)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using temporary token credential")
// } else {
// credential = common.NewCredential(kc.config.SecretID, kc.config.SecretKey)
// log.Debug().Str("method", "KnowledgeClient.Init").Msg("Using standard credential")
// }
// // Create client profile
// cpf := profile.NewClientProfile()
// cpf.HttpProfile.Endpoint = kc.config.Endpoint
// // Initialize client
// client, err := lkeap.NewClient(credential, kc.config.Region, cpf)
// if err != nil {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.Init").
// Msg("Failed to create LKEAP client")
// return err
// }
// kc.client = client
// kc.initted = true
// log.Info().Str("method", "KnowledgeClient.Init").Msg("Initialization successful")
// return nil
// }
// // QueryRewriteRequest defines the input for a query rewrite operation.
// type QueryRewriteRequest struct {
// Query string // The input query to rewrite
// }
// // QueryRewriteResponse defines the output of a query rewrite operation.
// type QueryRewriteResponse struct {
// RewrittenQuery string // The rewritten query
// RawResponse *lkeap.QueryRewriteResponse
// }
// // QueryRewrite performs a query rewrite using the Tencent Cloud LKEAP API.
// func (kc *KnowledgeClient) QueryRewrite(ctx context.Context, req QueryRewriteRequest) (*QueryRewriteResponse, error) {
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("query", req.Query).
// Msg("Starting query rewrite operation")
// if !kc.initted {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("Client not initialized")
// //return nil, errors.New("knowledge client not initialized; call Init first")
// return nil,fmt.Errorf("knowledge client not initialized; call Init first")
// }
// if req.Query == "" {
// log.Error().Str("method", "KnowledgeClient.QueryRewrite").Msg("Query is required")
// //return nil, errors.New("query is required")
// return nil, fmt.Errorf("query is required")
// }
// // Create request
// tencentReq := lkeap.NewQueryRewriteRequest()
// tencentReq.Query = common.StringPtr(req.Query)
// // Perform request
// response, err := kc.client.QueryRewriteWithContext(ctx, tencentReq)
// if err != nil {
// if _, ok := err.(*errors.TencentCloudSDKError); ok {
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Tencent Cloud API error")
// return nil, fmt.Errorf("tencent cloud api error: %w", err)
// }
// log.Error().
// Err(err).
// Str("method", "KnowledgeClient.QueryRewrite").
// Msg("Failed to perform query rewrite")
// return nil, fmt.Errorf("query rewrite failed: %w", err)
// }
// // Extract rewritten query
// var rewrittenQuery string
// if response.Response.RewrittenQuery != nil {
// rewrittenQuery = *response.Response.RewrittenQuery
// }
// result := &QueryRewriteResponse{
// RewrittenQuery: rewrittenQuery,
// RawResponse: response,
// }
// log.Info().
// Str("method", "KnowledgeClient.QueryRewrite").
// Str("rewritten_query", rewrittenQuery).
// Str("raw_response", response.ToJsonString()).
// Msg("Query rewrite operation completed successfully")
// return result, nil
// }
// Construct messages
messages := []*lkeap.Message{
{
Role: common.StringPtr("user"),
Content: common.StringPtr(userQuestion),
},
{
Role: common.StringPtr("assistant"),
Content: common.StringPtr(assistantAnswer),
},
}
// Append history summary as an assistant message if provided
if historySummary != "" {
messages = append(messages, &lkeap.Message{
Role: common.StringPtr("user"),
Content: common.StringPtr(fmt.Sprintf("Conversation summary: %s", historySummary)),
})
}
// Create request
req := QueryRewriteRequest{
Messages: messages,
Model: "lke-query-rewrite-base",
}
// Call QueryRewrite
return kc.QueryRewrite(ctx, req)
}
// Copyright 2025
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package knowledge
import (
"context"
"os"
"testing"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
lkeap "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/lkeap/v20240522"
)
func TestMain(m *testing.M) {
// Configure zerolog for human-readable console output during tests
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
// Run tests
os.Exit(m.Run())
}
func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// Warning: Do not hardcode credentials in production code. Use environment variables or a secure vault.
// The credentials below are assumed to be placeholders for testing purposes.
os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF")
os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
defer func() {
os.Unsetenv("TENCENTCLOUD_SECRET_ID")
os.Unsetenv("TENCENTCLOUD_SECRET_KEY")
}()
// Create client configuration
config := ClientConfig{
Endpoint: "lkeap.tencentcloudapi.com",
Region: "ap-guangzhou",
}
// Initialize client
client := NewKnowledgeClient(config)
ctx := context.Background()
// Test cases
tests := []struct {
name string
messages []*lkeap.Message
model string
expectError bool
}{
{
name: "ValidMultiTurnConversation",
messages: []*lkeap.Message{
{
Role: common.StringPtr("user"),
Content: common.StringPtr("What is the capital of France?"),
},
{
Role: common.StringPtr("assistant"),
Content: common.StringPtr("The capital of France is Paris."),
},
{
Role: common.StringPtr("user"),
Content: common.StringPtr("Tell me more about Paris."),
},
{
Role: common.StringPtr("assistant"),
Content: common.StringPtr("Paris is known for its art, culture, and landmarks like the Eiffel Tower."),
},
},
model: "",
expectError: true, // Expect error unless valid credentials are used
},
{
name: "EmptyMessages",
messages: []*lkeap.Message{},
model: "",
expectError: true,
},
{
name: "InvalidRole",
messages: []*lkeap.Message{
{
Role: common.StringPtr("invalid-role"),
Content: common.StringPtr("Test query"),
},
},
model: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize client for each test
if err := client.Init(ctx); err != nil {
t.Fatalf("Failed to initialize KnowledgeClient: %v", err)
}
// Perform query rewrite
req := QueryRewriteRequest{
Messages: tt.messages,
Model: tt.model,
}
resp, err := client.QueryRewrite(ctx, req)
// Check error expectation
if tt.expectError {
if err == nil {
t.Error("Expected error, got none")
} else {
log.Debug().
Str("method", "TestKnowledgeClient_QueryRewrite").
Str("test_name", tt.name).
Err(err).
Msg("Received expected error")
}
return
}
// Check response (only for non-error cases)
if err != nil {
t.Fatalf("QueryRewrite failed: %v", err)
}
if resp.RewrittenQuery == "" {
t.Error("Expected non-empty rewritten query")
}
// if resp.RequestId == "" {
// t.Error("Expected non-empty request ID")
// }
// if resp.Usage == nil {
// t.Error("Expected non-nil usage")
// }
log.Info().
Str("method", "TestKnowledgeClient_QueryRewrite").
Str("test_name", tt.name).
Str("rewritten_query", resp.RewrittenQuery).
// Str("request_id", resp.RequestId).
// Interface("usage", resp.Usage).
Msg("Query rewrite successful")
})
}
}
// // Copyright 2025 // // Copyright 2025
// // // //
// // Licensed under the Apache License, Version 2.0 (the "License"); // // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -217,15 +36,8 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -217,15 +36,8 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// os.Exit(m.Run()) // os.Exit(m.Run())
// } // }
// /*
// SecretId:AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF
// SecretKey:rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd
// */
// func TestKnowledgeClient_QueryRewrite(t *testing.T) { // func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// // Set up environment variables for testing (mock credentials) // // Warning: Do not hardcode credentials in production code. Use environment variables or a secure vault.
// os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF") // os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF")
// os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd") // os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
// defer func() { // defer func() {
...@@ -251,6 +63,25 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -251,6 +63,25 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// expectError bool // expectError bool
// }{ // }{
// { // {
// name: "CurlPayload",
// messages: []*lkeap.Message{
// {
// Role: common.StringPtr("user"),
// Content: common.StringPtr("你的家在哪里"),
// },
// {
// Role: common.StringPtr("assistant"),
// Content: common.StringPtr("国内"),
// },
// {
// Role: common.StringPtr("user"),
// Content: common.StringPtr("国内哪里"),
// },
// },
// model: "lke-query-rewrite-base",
// expectError: true,
// },
// {
// name: "ValidMultiTurnConversation", // name: "ValidMultiTurnConversation",
// messages: []*lkeap.Message{ // messages: []*lkeap.Message{
// { // {
...@@ -270,13 +101,13 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -270,13 +101,13 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// Content: common.StringPtr("Paris is known for its art, culture, and landmarks like the Eiffel Tower."), // Content: common.StringPtr("Paris is known for its art, culture, and landmarks like the Eiffel Tower."),
// }, // },
// }, // },
// model: "default-model", // model: "",
// expectError: true, // Expect error due to mock credentials // expectError: true,
// }, // },
// { // {
// name: "EmptyMessages", // name: "EmptyMessages",
// messages: []*lkeap.Message{}, // messages: []*lkeap.Message{},
// model: "default-model", // model: "",
// expectError: true, // expectError: true,
// }, // },
// { // {
...@@ -287,7 +118,7 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -287,7 +118,7 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// Content: common.StringPtr("Test query"), // Content: common.StringPtr("Test query"),
// }, // },
// }, // },
// model: "default-model", // model: "",
// expectError: true, // expectError: true,
// }, // },
// } // }
...@@ -320,9 +151,9 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -320,9 +151,9 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// return // return
// } // }
// // Check response (only for non-error cases) // // Check response
// if err != nil { // if err != nil {
// t.Fatalf("QueryRewrite failed: %v", err) // t.Errorf("QueryRewrite failed: %v", err)
// } // }
// if resp.RewrittenQuery == "" { // if resp.RewrittenQuery == "" {
// t.Error("Expected non-empty rewritten query") // t.Error("Expected non-empty rewritten query")
...@@ -335,4 +166,351 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) { ...@@ -335,4 +166,351 @@ func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// Msg("Query rewrite successful") // Msg("Query rewrite successful")
// }) // })
// } // }
// } // }
\ No newline at end of file
// func TestKnowledgeClient_QueryRewriteWithSummary(t *testing.T) {
// // Warning: Do not hardcode credentials in production code. Use environment variables or a secure vault.
// os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF")
// os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
// defer func() {
// os.Unsetenv("TENCENTCLOUD_SECRET_ID")
// os.Unsetenv("TENCENTCLOUD_SECRET_KEY")
// }()
// // Create client configuration
// config := ClientConfig{
// Endpoint: "lkeap.tencentcloudapi.com",
// Region: "ap-guangzhou",
// }
// // Initialize client
// client := NewKnowledgeClient(config)
// ctx := context.Background()
// // Test cases
// tests := []struct {
// name string
// userQuestion string
// assistantAnswer string
// historySummary string
// expectError bool
// }{
// {
// name: "ValidWithSummary",
// userQuestion: "你的家在哪里",
// assistantAnswer: "国内",
// historySummary: "User asked about location preferences earlier.",
// expectError: true, // Expect error due to potentially invalid credentials
// },
// {
// name: "ValidWithoutSummary",
// userQuestion: "你的家在哪里",
// assistantAnswer: "国内",
// historySummary: "",
// expectError: true,
// },
// {
// name: "EmptyQuestion",
// userQuestion: "",
// assistantAnswer: "国内",
// historySummary: "Summary",
// expectError: true,
// },
// {
// name: "EmptyAnswer",
// userQuestion: "你的家在哪里",
// assistantAnswer: "",
// historySummary: "Summary",
// expectError: true,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// // Initialize client for each test
// if err := client.Init(ctx); err != nil {
// t.Fatalf("Failed to initialize KnowledgeClient: %v", err)
// }
// // Perform query rewrite with summary
// resp, err := client.QueryRewriteWithSummary(ctx, tt.userQuestion, tt.assistantAnswer, tt.historySummary)
// // Check error expectation
// if tt.expectError {
// if err == nil {
// t.Error("Expected error, got none")
// } else {
// log.Debug().
// Str("method", "TestKnowledgeClient_QueryWithSummary").
// Str("test_name", tt.name).
// Str("error", err.Error()).
// Msg("Received expected error")
// }
// return
// }
// // Check response
// if err != nil {
// t.Errorf("QueryRewriteWithSummary failed: %v", err)
// }
// if resp.RewrittenQuery == "" {
// t.Error("Expected non-empty rewritten query")
// }
// log.Info().
// Str("method", "TestKnowledgeClient_QueryWithSummary").
// Str("test_name", tt.name).
// Str("rewritten_query", resp.RewrittenQuery).
// Msg("Query rewrite with summary successful")
// })
// }
// }
// Copyright 2025
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package knowledge
import (
"context"
"os"
"testing"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
lkeap "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/lkeap/v20240522"
)
func TestMain(m *testing.M) {
// Configure zerolog for human-readable console output during tests
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
// Run tests
os.Exit(m.Run())
}
func TestKnowledgeClient_QueryRewrite(t *testing.T) {
// Warning: Do not hardcode credentials in production code. Use environment variables or a secure vault.
// The credentials below are placeholders for testing purposes.
os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF")
os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
defer func() {
os.Unsetenv("TENCENTCLOUD_SECRET_ID")
os.Unsetenv("TENCENTCLOUD_SECRET_KEY")
}()
// Create client configuration
config := ClientConfig{
Endpoint: "lkeap.tencentcloudapi.com",
Region: "ap-guangzhou",
}
// Initialize client
client := NewKnowledgeClient(config)
ctx := context.Background()
// Test cases
tests := []struct {
name string
messages []*lkeap.Message
model string
expectError bool
}{
{
name: "CurlPayload",
messages: []*lkeap.Message{
{
Role: common.StringPtr("user"),
Content: common.StringPtr("你的家在哪里"),
},
{
Role: common.StringPtr("assistant"),
Content: common.StringPtr("国内"),
},
{
Role: common.StringPtr("user"),
Content: common.StringPtr("国内哪里"),
},
},
model: "lke-query-rewrite-base",
expectError: true, // Expect error due to potentially invalid credentials
},
// {
// name: "ValidMultiTurnConversation",
// messages: []*lkeap.Message{
// {
// Role: common.StringPtr("user"),
// Content: common.StringPtr("What is the capital of France?"),
// },
// {
// Role: common.StringPtr("assistant"),
// Content: common.StringPtr("The capital of France is Paris."),
// },
// {
// Role: common.StringPtr("user"),
// Content: common.StringPtr("Tell me more about Paris."),
// },
// {
// Role: common.StringPtr("assistant"),
// Content: common.StringPtr("Paris is known for its art, culture, and landmarks like the Eiffel Tower."),
// },
// },
// model: "",
// expectError: true,
// },
// {
// name: "EmptyMessages",
// messages: []*lkeap.Message{},
// model: "",
// expectError: true,
// },
// {
// name: "InvalidRole",
// messages: []*lkeap.Message{
// {
// Role: common.StringPtr("invalid-role"),
// Content: common.StringPtr("Test query"),
// },
// },
// model: "",
// expectError: true,
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize client for each test
if err := client.Init(ctx); err != nil {
t.Fatalf("Failed to initialize KnowledgeClient: %v", err)
}
// Perform query rewrite
req := QueryRewriteRequest{
Messages: tt.messages,
Model: tt.model,
}
resp, err := client.QueryRewrite(ctx, req)
// Check error expectation
if tt.expectError {
if err == nil {
t.Error("Expected error, got none")
} else {
log.Debug().
Str("method", "TestKnowledgeClient_QueryRewrite").
Str("test_name", tt.name).
Err(err).
Msg("Received expected error")
}
return
}
// Check response (only for non-error cases)
if err != nil {
t.Fatalf("QueryRewrite failed: %v", err)
}
if resp.RewrittenQuery == "" {
t.Error("Expected non-empty rewritten query")
}
log.Info().
Str("method", "TestKnowledgeClient_QueryRewrite").
Str("test_name", tt.name).
Str("rewritten_query", resp.RewrittenQuery).
Msg("Query rewrite successful")
})
}
}
func TestKnowledgeClient_QueryRewriteWithSummary(t *testing.T) {
// Warning: Do not hardcode credentials in production code. Use environment variables or a secure vault.
os.Setenv("TENCENTCLOUD_SECRET_ID", "AKID64oLfmfLtESUJ6i8LPSM4gCVbiniQuBF")
os.Setenv("TENCENTCLOUD_SECRET_KEY", "rX2JMBnBMJ2YqulOo37xa5OUMSN4Xnpd")
defer func() {
os.Unsetenv("TENCENTCLOUD_SECRET_ID")
os.Unsetenv("TENCENTCLOUD_SECRET_KEY")
}()
// Create client configuration
config := ClientConfig{
Endpoint: "lkeap.tencentcloudapi.com",
Region: "ap-guangzhou",
}
// Initialize client
client := NewKnowledgeClient(config)
ctx := context.Background()
// Test cases
tests := []struct {
name string
userQuestion string
assistantAnswer string
historySummary string
expectError bool
}{
{
name: "ValidWithSummary",
userQuestion: "你的家在哪里",
assistantAnswer: "国内",
historySummary: "User asked about location preferences earlier.",
expectError: true, // Expect error due to potentially invalid credentials
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize client for each test
if err := client.Init(ctx); err != nil {
t.Fatalf("Failed to initialize KnowledgeClient: %v", err)
}
// Perform query rewrite with summary
resp, err := client.QueryRewriteWithSummary(ctx, tt.userQuestion, tt.assistantAnswer, tt.historySummary)
// Check error expectation
if tt.expectError {
if err == nil {
t.Error("Expected error, got none")
} else {
log.Debug().
Str("method", "TestKnowledgeClient_QueryWithSummary").
Str("test_name", tt.name).
Str("error", err.Error()).
Msg("Received expected error")
}
return
}
// Check response
if err != nil {
t.Errorf("QueryRewriteWithSummary failed: %v", err)
}
if resp.RewrittenQuery == "" {
t.Error("Expected non-empty rewritten query")
}
log.Info().
Str("method", "TestKnowledgeClient_QueryWithSummary").
Str("test_name", tt.name).
Str("rewritten_query", resp.RewrittenQuery).
Msg("Query rewrite with summary successful")
})
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment