package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"time"

	_ "github.com/lib/pq"
)

// connString holds the value of the -dbconn command-line flag, the database
// connection string. It is empty unless the flag is supplied and parsed.
var connString = flag.String("dbconn", "", "database connection string")

// QA is one question/answer record as stored in the qa table. Apart from ID
// and CreatedAt every column is nullable, so those fields are pointers and a
// nil pointer stands for SQL NULL.
type QA struct {
	ID        int64     // primary key
	CreatedAt time.Time // creation time

	// Nullable text columns; nil means NULL.
	FromID   *string // from_id
	From     *string // "from"
	Question *string // question
	Answer   *string // answer
	Summary  *string // summary
	To       *string // "to"
	ToID     *string // to_id
}

// QAStore is the data-access interface for QA records.
type QAStore interface {
	// WriteQA inserts a record into, or updates a record in, the qa table
	// and returns the record's id.
	WriteQA(ctx context.Context, qa QA) (int64, error)

	// GetLatestQA reads the records for the given from_id from the
	// qa_latest_from_id view.
	GetLatestQA(ctx context.Context, fromID *string) ([]QA, error)
}

// qaStore is the database/sql-backed implementation of QAStore.
type qaStore struct {
	db *sql.DB // handle through which every query is issued
}

// NewQAStore returns a QAStore that runs all of its queries through db.
func NewQAStore(db *sql.DB) QAStore {
	store := new(qaStore)
	store.db = db
	return store
}

// InitQAStore opens the PostgreSQL database named by the -dbconn flag,
// checks that it is reachable, and returns a QAStore backed by it.
//
// Credentials must never be embedded in source. Any password previously
// committed here should be treated as compromised and rotated. An empty
// -dbconn value is reported as an error rather than replaced by a built-in
// default. The value is empty when the flag is not given, or when flag.Parse
// has not run yet.
//
// If the reachability check fails, the *sql.DB is closed before the error is
// returned.
func InitQAStore() (QAStore, error) {
	// Bound the reachability check so an unreachable host cannot hang startup.
	const pingTimeout = 10 * time.Second

	// Use the package-level flag; the old local variable shadowed it, so
	// -dbconn was silently ignored.
	dsn := *connString
	if dsn == "" {
		return nil, fmt.Errorf("database connection string is empty: set the -dbconn flag")
	}

	// sql.Open only validates its arguments; it does not dial the server.
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		db.Close()
		return nil, fmt.Errorf("ping database: %w", err)
	}

	return NewQAStore(db), nil
}

// GetLatestQA returns the rows of the qa_latest_from_id view whose from_id
// equals fromID. A nil fromID matches the rows whose from_id is NULL.
//
// NULL columns come back as nil pointer fields. When no row matches, the
// returned slice is nil and the error is nil.
func (s *qaStore) GetLatestQA(ctx context.Context, fromID *string) ([]QA, error) {
	const query = `
        SELECT id, created_at, question, answer, summary, "from", "to", from_id, to_id
        FROM qa_latest_from_id
        WHERE from_id = $1 OR (from_id IS NULL AND $1 IS NULL)`

	// A nil fromID is bound as an untyped nil so that $1 is sent as NULL.
	var param interface{}
	if fromID != nil {
		param = fromID
	}

	rows, err := s.db.QueryContext(ctx, query, param)
	if err != nil {
		return nil, fmt.Errorf("query qa_latest_from_id: %w", err)
	}
	defer rows.Close()

	// optional converts a scanned NullString into QA's nil-or-pointer form.
	optional := func(ns sql.NullString) *string {
		if !ns.Valid {
			return nil
		}
		v := ns.String
		return &v
	}

	var out []QA
	for rows.Next() {
		var rec QA
		var question, answer, summary, from, to, fromCol, toCol sql.NullString
		err := rows.Scan(&rec.ID, &rec.CreatedAt, &question, &answer, &summary, &from, &to, &fromCol, &toCol)
		if err != nil {
			return nil, fmt.Errorf("scan row: %w", err)
		}
		rec.Question = optional(question)
		rec.Answer = optional(answer)
		rec.Summary = optional(summary)
		rec.From = optional(from)
		rec.To = optional(to)
		rec.FromID = optional(fromCol)
		rec.ToID = optional(toCol)
		out = append(out, rec)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("row iteration: %w", err)
	}
	return out, nil
}

// WriteQA persists qa in the qa table and returns the id of the affected row.
//
// When qa.ID is zero, a new row is inserted and the database-generated id is
// returned. Otherwise the row with that id is updated: every content column
// is overwritten, and nil fields are written as NULL. CreatedAt is never
// written by either statement.
//
// Updating an id that has no row yields an error that wraps sql.ErrNoRows,
// so callers can detect it with errors.Is.
func (s *qaStore) WriteQA(ctx context.Context, qa QA) (int64, error) {
	// Values for $1..$7, shared by both statements so the column order is
	// defined in exactly one place.
	args := []interface{}{
		derefString(qa.Question),
		derefString(qa.Answer),
		derefString(qa.Summary),
		derefString(qa.From),
		derefString(qa.To),
		derefString(qa.FromID),
		derefString(qa.ToID),
	}

	if qa.ID == 0 {
		// Insert a new row and let the database assign its id.
		query := `
        INSERT INTO qa (question, answer, summary, "from", "to", from_id, to_id)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        RETURNING id`
		var newID int64
		if err := s.db.QueryRowContext(ctx, query, args...).Scan(&newID); err != nil {
			return 0, fmt.Errorf("insert qa: %w", err)
		}
		return newID, nil
	}

	// Update the existing row; its id is bound as $8.
	query := `
            UPDATE qa
            SET question = $1, answer = $2, summary = $3, "from" = $4, "to" = $5, from_id = $6, to_id = $7
            WHERE id = $8
            RETURNING id`
	var updatedID int64
	err := s.db.QueryRowContext(ctx, query, append(args, qa.ID)...).Scan(&updatedID)
	if err == sql.ErrNoRows {
		// Row.Scan returns sql.ErrNoRows unwrapped when UPDATE ... RETURNING
		// matches nothing. Wrap it with %w so callers can still detect it.
		return 0, fmt.Errorf("no record found with id %d: %w", qa.ID, err)
	}
	if err != nil {
		return 0, fmt.Errorf("update qa: %w", err)
	}
	return updatedID, nil
}

// stringPtr returns a pointer to a new copy of s. It is handy for filling the
// nullable *string fields of QA from plain string values.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// derefString turns an optional string into a database/sql query argument.
// A nil pointer becomes an untyped nil, which is bound as SQL NULL. A non-nil
// pointer becomes the string it points to.
func derefString(p *string) interface{} {
	if p != nil {
		return *p
	}
	return nil
}
