package executor

import (
	"context"
	"fmt"
	"log"
	"strings"

	"chain-sql/internal/database"
)

// allowedSqlTypes is the allowlist of SQL column types (common PostgreSQL
// types). Keys are lowercase base type names without any "(n)" modifier.
// HandleTableCreated refuses to run DDL for any column whose type is not in
// this set (the check is performed by validateSqlType).
var allowedSqlTypes = func() map[string]bool {
	names := []string{
		// Integer types.
		"int", "integer", "bigint", "smallint",
		// Character types.
		"text", "varchar", "char", "character varying",
		// Boolean.
		"boolean", "bool",
		// Exact and floating-point numerics.
		"decimal", "numeric", "real", "double precision",
		// Date/time.
		"timestamp", "date", "time",
		// JSON documents.
		"json", "jsonb",
		// Raw bytes.
		"bytea",
	}
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()

// Executor builds SQL statements from contract events (schema creation,
// table creation, row insertion) and runs them through a database handle.
type Executor struct {
	db *database.DB
}

// NewExecutor returns an Executor that issues all of its statements via db.
func NewExecutor(db *database.DB) *Executor {
	ex := new(Executor)
	ex.db = db
	return ex
}

// validateSqlType reports whether t is an allowlisted SQL column type,
// optionally followed by a numeric type modifier such as "(50)" or "(10, 2)".
//
// The base type (the text before the first "(") is lowercased, trimmed and
// looked up in allowedSqlTypes. If a modifier is present it must be one or
// two comma-separated unsigned decimal integers closed by ")", with only
// whitespace after it.
//
// The whole string must be checked, not just the base type, because
// HandleTableCreated splices t verbatim into CREATE TABLE. Checking only the
// prefix would accept input such as "varchar(1)); DROP TABLE x; --".
// Trailing clauses (e.g. "NOT NULL", "WITH TIME ZONE") are therefore rejected.
func validateSqlType(t string) bool {
	base, modifier, hasModifier := strings.Cut(t, "(")
	if !allowedSqlTypes[strings.TrimSpace(strings.ToLower(base))] {
		return false
	}
	if !hasModifier {
		return true
	}

	// Everything after "(" must be "<digits>[,<digits>])" plus optional whitespace.
	args, tail, closed := strings.Cut(modifier, ")")
	if !closed || strings.TrimSpace(tail) != "" {
		return false
	}
	nums := strings.Split(args, ",")
	if len(nums) > 2 {
		return false
	}
	for _, n := range nums {
		n = strings.TrimSpace(n)
		if n == "" {
			return false
		}
		for i := 0; i < len(n); i++ {
			if n[i] < '0' || n[i] > '9' {
				return false
			}
		}
	}
	return true
}

// EnsureSchema creates a dedicated schema for contractAddr if it does not
// already exist, so each contract's tables live in their own namespace.
// The address is passed through QuoteIdentifier, so it is used as a quoted
// identifier and its letter case is preserved as given.
func (e *Executor) EnsureSchema(ctx context.Context, contractAddr string) error {
	stmt := "CREATE SCHEMA IF NOT EXISTS " + QuoteIdentifier(contractAddr)
	return e.db.Exec(ctx, stmt)
}

// HandleTableCreated handles a table-creation event. It checks every column's
// SQL type with validateSqlType, then issues CREATE TABLE IF NOT EXISTS for
// tableName inside the schema named after contractAddr.
//
// The schema, table and column names are quoted with QuoteIdentifier. Column
// types are inserted verbatim, so the safety of the generated DDL relies on
// validateSqlType. When one or more columns have IsPrimaryKey set, a single
// (possibly composite) PRIMARY KEY constraint listing them in input order is
// appended.
//
// If any type is rejected it returns an error before touching the database.
// Otherwise it logs the DDL and returns the result of e.db.Exec.
func (e *Executor) HandleTableCreated(ctx context.Context, contractAddr string, tableName string, columns []struct {
	Name         string
	SqlType      string
	IsPrimaryKey bool
}) error {
	defs := make([]string, 0, len(columns)+1)
	var keyCols []string

	for _, c := range columns {
		// Security gate: refuse any non-allowlisted type before building SQL.
		if ok := validateSqlType(c.SqlType); !ok {
			return fmt.Errorf("security error: invalid or forbidden sql type '%s' for column '%s'", c.SqlType, c.Name)
		}

		quoted := QuoteIdentifier(c.Name)
		defs = append(defs, quoted+" "+c.SqlType)
		if c.IsPrimaryKey {
			keyCols = append(keyCols, quoted)
		}
	}

	if len(keyCols) != 0 {
		defs = append(defs, "PRIMARY KEY ("+strings.Join(keyCols, ", ")+")")
	}

	ddl := "CREATE TABLE IF NOT EXISTS " +
		QuoteIdentifier(contractAddr) + "." + QuoteIdentifier(tableName) +
		" (" + strings.Join(defs, ", ") + ")"

	log.Printf("Executing DDL: %s", ddl)
	return e.db.Exec(ctx, ddl)
}

// HandleDataInserted handles a row-insertion event. It issues a parameterized
// INSERT into tableName inside the schema named after contractAddr.
//
// columns[i] names the column that receives values[i]. Identifiers are quoted
// with QuoteIdentifier. Every value is bound through a $n placeholder, so
// values never become part of the SQL text.
//
// It returns an error without touching the database when the event has no
// columns or when the numbers of columns and values differ. Otherwise it
// returns the result of e.db.Exec.
func (e *Executor) HandleDataInserted(ctx context.Context, contractAddr string, tableName string, columns []string, values []string) error {
	// Reject malformed events up front. A length mismatch would only fail at
	// the database, and an empty column list would generate
	// "INSERT INTO ... () VALUES ()", which PostgreSQL rejects as a syntax error.
	if len(columns) != len(values) {
		return fmt.Errorf("insert into %q: %d columns but %d values", tableName, len(columns), len(values))
	}
	if len(columns) == 0 {
		return fmt.Errorf("insert into %q: no columns given", tableName)
	}

	schemaName := QuoteIdentifier(contractAddr)
	safeTableName := QuoteIdentifier(tableName)

	// Build quoted column names, positional placeholders ($1, $2, ...) and
	// the matching bind arguments in one pass. The lengths are equal (checked above).
	safeCols := make([]string, len(columns))
	placeholders := make([]string, len(values))
	args := make([]interface{}, len(values))
	for i := range columns {
		safeCols[i] = QuoteIdentifier(columns[i])
		placeholders[i] = fmt.Sprintf("$%d", i+1)
		args[i] = values[i]
	}

	sql := fmt.Sprintf(
		`INSERT INTO %s.%s (%s) VALUES (%s)`,
		schemaName,
		safeTableName,
		strings.Join(safeCols, ", "),
		strings.Join(placeholders, ", "),
	)

	return e.db.Exec(ctx, sql, args...)
}

// QuoteIdentifier returns s as a double-quoted SQL identifier. Every embedded
// double quote is doubled, so arbitrary names cannot break out of the
// identifier (preventing SQL injection through names). Reserved words also
// become usable as table or column names.
func QuoteIdentifier(s string) string {
	const dq = `"`
	escaped := strings.ReplaceAll(s, dq, dq+dq)
	// fmt.Sprint adds no separators between string operands.
	return fmt.Sprint(dq, escaped, dq)
}
