package database

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// DB wraps a pgx connection pool and exposes the operations on the
// _chainsql_* system tables defined in this file. Construct it with New.
type DB struct {
	pool *pgxpool.Pool
}

// New parses dsn, opens a pgx connection pool, verifies connectivity with a
// ping, and creates the system tables if they do not exist yet.
//
// If the ping or the schema initialization fails, the freshly created pool
// is closed before the error is returned, so a failed New never leaks
// connections.
func New(ctx context.Context, dsn string) (*DB, error) {
	config, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to parse dsn: %w", err)
	}

	pool, err := pgxpool.NewWithConfig(ctx, config)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to db: %w", err)
	}

	if err := pool.Ping(ctx); err != nil {
		// The caller never receives the pool on this path, so release it here.
		pool.Close()
		return nil, fmt.Errorf("failed to ping db: %w", err)
	}

	db := &DB{pool: pool}
	if err := db.initSchema(ctx); err != nil {
		pool.Close()
		return nil, err
	}

	return db, nil
}

// Close shuts down the underlying connection pool.
//
// It is a no-op on a nil *DB or a DB whose pool was never set, so callers
// may defer it without first checking how construction went.
func (d *DB) Close() {
	if d == nil || d.pool == nil {
		return
	}
	d.pool.Close()
}

// Ping checks that the database is reachable. It forwards to the pool's
// Ping and returns whatever error the pool reports (nil when reachable).
func (d *DB) Ping(ctx context.Context) error {
	pingErr := d.pool.Ping(ctx)
	return pingErr
}

// Exec runs an arbitrary SQL statement with the given positional arguments
// ($1, $2, ...). The command tag produced by the pool is discarded; only
// the error is returned.
//
// Pass untrusted values through args rather than formatting them into sql.
func (d *DB) Exec(ctx context.Context, sql string, args ...any) error {
	if _, err := d.pool.Exec(ctx, sql, args...); err != nil {
		return err
	}
	return nil
}

// initSchema creates the system tables if they do not already exist:
//   - _chainsql_instances: registry of contract instances
//   - _chainsql_cursor: sync cursor (last_block) per sync_key
//
// Both CREATE statements are sent in a single argument-less Exec call.
// NOTE(review): this presumably relies on pgx using the simple query
// protocol when no arguments are given, since multi-statement strings are
// not accepted by the extended protocol — confirm if the exec mode changes.
// On success it logs a confirmation line via the standard logger.
func (d *DB) initSchema(ctx context.Context) error {
	const schemaDDL = `
    CREATE TABLE IF NOT EXISTS _chainsql_instances (
        contract_address VARCHAR(42) PRIMARY KEY,
        owner_address VARCHAR(42),
        created_at_block BIGINT,
        status VARCHAR(20) DEFAULT 'active'
    );
    
    CREATE TABLE IF NOT EXISTS _chainsql_cursor (
        sync_key VARCHAR(50) PRIMARY KEY,
        last_block BIGINT,
        updated_at TIMESTAMP
    );
    `
	if _, execErr := d.pool.Exec(ctx, schemaDDL); execErr != nil {
		return fmt.Errorf("failed to init schema: %w", execErr)
	}
	log.Println("Database schema initialized")
	return nil
}

// RegisterInstance records a newly discovered contract instance with status
// 'active'. If address is already registered, the insert is skipped
// (ON CONFLICT DO NOTHING): the existing row is left untouched and no error
// is returned.
func (d *DB) RegisterInstance(ctx context.Context, address string, owner string, blockNumber uint64) error {
	const insertInstance = `
			INSERT INTO _chainsql_instances (contract_address, owner_address, created_at_block, status)
			VALUES ($1, $2, $3, 'active')
			ON CONFLICT (contract_address) DO NOTHING
	`
	if _, err := d.pool.Exec(ctx, insertInstance, address, owner, blockNumber); err != nil {
		return err
	}
	return nil
}

// GetLastBlock returns the sync height stored in _chainsql_cursor for key.
// When no row exists for key it returns 0 and a nil error.
//
// NOTE(review): last_block is nullable in the schema created by initSchema;
// a NULL (or negative) value cannot be scanned into uint64 and would surface
// as an error here — presumably UpdateLastBlock is the only writer.
func (d *DB) GetLastBlock(ctx context.Context, key string) (uint64, error) {
	const query = `SELECT last_block FROM _chainsql_cursor WHERE sync_key = $1`

	var blockNum uint64
	err := d.pool.QueryRow(ctx, query, key).Scan(&blockNum)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// No cursor yet for this key.
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("failed to get last block for %q: %w", key, err)
	}
	return blockNum, nil
}

// UpdateLastBlock upserts the sync height for key. It inserts a new cursor
// row or overwrites last_block on the existing one, and sets updated_at to
// NOW() in both cases. The overwrite is unconditional, so a smaller blockNum
// moves the cursor backwards.
func (d *DB) UpdateLastBlock(ctx context.Context, key string, blockNum uint64) error {
	const upsertCursor = `
			INSERT INTO _chainsql_cursor (sync_key, last_block, updated_at)
			VALUES ($1, $2, NOW())
			ON CONFLICT (sync_key) 
			DO UPDATE SET last_block = $2, updated_at = NOW()
	`
	_, execErr := d.pool.Exec(ctx, upsertCursor, key, blockNum)
	return execErr
}

// GetActiveInstances returns the contract addresses of every registered
// instance whose status is 'active', in no particular order (the query has
// no ORDER BY). The result is nil when there are no active instances.
func (d *DB) GetActiveInstances(ctx context.Context) ([]string, error) {
	sql := `SELECT contract_address FROM _chainsql_instances WHERE status = 'active'`
	rows, err := d.pool.Query(ctx, sql)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var addrs []string
	for rows.Next() {
		var addr string
		if err := rows.Scan(&addr); err != nil {
			return nil, err
		}
		addrs = append(addrs, addr)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// reading fails; without this check a mid-stream error would silently
	// yield a truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return addrs, nil
}
