package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"log"
	"time"

	_ "github.com/lib/pq"
)

// connString holds the value of the -dbconn command-line flag, the
// connection string that mainQA hands to sql.Open("postgres", ...).
// It reads as "" until flag.Parse runs, and stays "" if the flag is absent.
var connString = flag.String("dbconn", "", "database connection string")

// QA represents one record of the qa table. GetLatestQA reads the same
// columns back from the latest_qa view. Each pointer field is nil when its
// column is NULL.
type QA struct {
	ID        int64     // primary key; WriteQA treats 0 as "insert a new row"
	CreatedAt time.Time // creation time; read back but never written by WriteQA
	UserID    *int64    // nullable user ID
	Username  *string   // nullable username
	Question  *string   // nullable question text
	Answer    *string   // nullable answer text
}

// QAStore is the data-access (DAO) interface for QA records.
type QAStore interface {
	// GetLatestQA reads the rows of the latest_qa view that belong to userID.
	// What "latest" means is defined by the view itself, which is not visible
	// here. In the qaStore implementation, a nil userID matches the rows whose
	// user_id is NULL.
	GetLatestQA(ctx context.Context, userID *int64) ([]QA, error)
	// WriteQA inserts qa into the qa table when qa.ID is 0. Otherwise it
	// updates the existing row with that ID. Either way it returns the row's ID.
	WriteQA(ctx context.Context, qa QA) (int64, error)
}

// qaStore is the database/sql-backed implementation of QAStore.
type qaStore struct {
	db *sql.DB // handle used for every query; qaStore never closes it
}

// NewQAStore returns a QAStore that runs its queries against db.
// The returned store never closes db; closing it remains the caller's job.
func NewQAStore(db *sql.DB) QAStore {
	store := new(qaStore)
	store.db = db
	return store
}

// GetLatestQA returns every row of the latest_qa view whose user_id equals
// *userID. A nil userID selects the rows whose user_id is NULL.
//
// NULL columns come back as nil pointer fields in QA. When no row matches,
// the result is a nil slice with a nil error. Query, scan and iteration
// failures are wrapped with the step that failed.
func (s *qaStore) GetLatestQA(ctx context.Context, userID *int64) ([]QA, error) {
	const stmt = `
        SELECT id, created_at, user_id, username, question, answer
        FROM latest_qa
        WHERE user_id = $1 OR (user_id IS NULL AND $1 IS NULL)`

	// Bind an untyped nil, rather than a typed nil *int64, when no user is given.
	var arg interface{}
	if userID != nil {
		arg = userID
	}

	rows, err := s.db.QueryContext(ctx, stmt, arg)
	if err != nil {
		return nil, fmt.Errorf("query latest_qa: %w", err)
	}
	defer rows.Close()

	var out []QA
	for rows.Next() {
		var row QA
		// Scanning into a **T destination makes database/sql store nil for a
		// NULL column, and a freshly allocated value otherwise.
		dests := []interface{}{
			&row.ID, &row.CreatedAt, &row.UserID,
			&row.Username, &row.Question, &row.Answer,
		}
		if scanErr := rows.Scan(dests...); scanErr != nil {
			return nil, fmt.Errorf("scan row: %w", scanErr)
		}
		out = append(out, row)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("row iteration: %w", iterErr)
	}
	return out, nil
}

// WriteQA persists qa in the qa table and returns the row's ID.
//
// If qa.ID is 0, a new row is inserted and the ID generated by the database
// is returned. Columns not listed in the INSERT, such as created_at, are
// presumably filled by database defaults (not visible here).
//
// If qa.ID is non-zero, the row with that ID has its user_id, username,
// question and answer overwritten. If no such row exists, the returned error
// wraps sql.ErrNoRows, so callers can check errors.Is(err, sql.ErrNoRows).
//
// Nil pointer fields are bound as SQL NULL by database/sql's default
// parameter conversion.
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
	if qa.ID != 0 {
		const updateStmt = `
            UPDATE qa
            SET user_id = $1, username = $2, question = $3, answer = $4
            WHERE id = $5
            RETURNING id`
		var updatedID int64
		err := s.db.QueryRowContext(ctx, updateStmt,
			qa.UserID, qa.Username, qa.Question, qa.Answer, qa.ID).Scan(&updatedID)
		switch {
		case err == sql.ErrNoRows:
			// Row.Scan returns sql.ErrNoRows unwrapped, so == is reliable here.
			// Keep the sentinel in the chain (%w) so "not found" can be detected
			// without matching on the message text.
			return 0, fmt.Errorf("no record found with id %d: %w", qa.ID, sql.ErrNoRows)
		case err != nil:
			return 0, fmt.Errorf("update qa: %w", err)
		}
		return updatedID, nil
	}

	const insertStmt = `
        INSERT INTO qa (user_id, username, question, answer)
        VALUES ($1, $2, $3, $4)
        RETURNING id`
	var newID int64
	err := s.db.QueryRowContext(ctx, insertStmt,
		qa.UserID, qa.Username, qa.Question, qa.Answer).Scan(&newID)
	if err != nil {
		return 0, fmt.Errorf("insert qa: %w", err)
	}
	return newID, nil
}

// mainQA is a demo driver for QAStore. It parses flags, requires -dbconn,
// then runs runQAExample against that database. Any failure is logged and
// ends the process via log.Fatal.
func mainQA() {
	flag.Parse()

	if *connString == "" {
		log.Fatal("need -dbconn")
	}

	// The work lives in a helper that returns an error, so its deferred
	// db.Close actually runs. log.Fatal calls os.Exit, which skips pending defers.
	if err := runQAExample(context.Background(), *connString); err != nil {
		log.Fatal(err)
	}
}

// runQAExample connects to the PostgreSQL database at dsn and runs three demo
// steps:
//  1. print the latest_qa rows for user 101;
//  2. insert a new QA row;
//  3. update that row.
//
// It returns the first error encountered, prefixed with the step that failed.
func runQAExample(ctx context.Context, dsn string) error {
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return fmt.Errorf("open database: %w", err)
	}
	defer db.Close()

	// sql.Open does not necessarily connect. Ping so that a bad DSN or an
	// unreachable server is reported up front rather than as a query failure.
	if err := db.PingContext(ctx); err != nil {
		return fmt.Errorf("ping database: %w", err)
	}

	store := NewQAStore(db)

	// Example: read the latest QA rows for user_id=101.
	results, err := store.GetLatestQA(ctx, int64Ptr(101))
	if err != nil {
		return fmt.Errorf("get latest QA: %w", err)
	}
	for _, qa := range results {
		fmt.Printf("ID: %d, CreatedAt: %v, UserID: %v, Username: %v, Question: %v, Answer: %v\n",
			qa.ID, qa.CreatedAt, derefInt64(qa.UserID), derefString(qa.Username), derefString(qa.Question), derefString(qa.Answer))
	}

	// Example: insert a new QA row (ID 0 means insert).
	newQA := QA{
		UserID:   int64Ptr(101),
		Username: stringPtr("alice"),
		Question: stringPtr("What is AI?"),
		Answer:   stringPtr("AI is..."),
	}
	newID, err := store.WriteQA(ctx, newQA)
	if err != nil {
		return fmt.Errorf("write QA: %w", err)
	}
	fmt.Printf("Inserted QA with ID: %d\n", newID)

	// Example: update the row just inserted (non-zero ID means update).
	updateQA := QA{
		ID:       newID,
		UserID:   int64Ptr(101),
		Username: stringPtr("alice_updated"),
		Question: stringPtr("What is NLP?"),
		Answer:   stringPtr("NLP is..."),
	}
	updatedID, err := store.WriteQA(ctx, updateQA)
	if err != nil {
		return fmt.Errorf("update QA: %w", err)
	}
	fmt.Printf("Updated QA with ID: %d\n", updatedID)
	return nil
}

// Helper functions for building and unwrapping the nullable pointer fields of QA.

// int64Ptr returns a pointer to a new int64 holding i. Each call yields a
// distinct pointer, which makes it convenient for filling nullable fields
// such as QA.UserID.
func int64Ptr(i int64) *int64 {
	p := new(int64)
	*p = i
	return p
}

// stringPtr returns a pointer to a new string holding s. It is used to fill
// nullable *string fields of QA such as Username, Question and Answer.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// derefInt64 unwraps p for display. It returns the pointed-to int64, or an
// untyped nil interface when p is nil, so %v prints "<nil>".
func derefInt64(p *int64) interface{} {
	if p != nil {
		return *p
	}
	return nil
}

// derefString unwraps p for display. It returns the pointed-to string, or an
// untyped nil interface when p is nil, so %v prints "<nil>".
func derefString(p *string) interface{} {
	if p != nil {
		return *p
	}
	return nil
}
