package executor

import (
	"context"
	"fmt"
	"log"
	"strings"

	"chain-sql/internal/database"
)

// allowedSqlTypes is the whitelist of SQL base types (common PostgreSQL types)
// that may appear in a column definition. Keys are lower-case and carry no
// length/precision modifier. HandleTableCreated refuses any column whose type
// is not listed here.
var allowedSqlTypes = map[string]bool{
	// integers
	"smallint": true,
	"int":      true,
	"integer":  true,
	"bigint":   true,

	// character data
	"char":              true,
	"varchar":           true,
	"character varying": true,
	"text":              true,

	// booleans
	"bool":    true,
	"boolean": true,

	// exact and floating-point numerics
	"numeric":          true,
	"decimal":          true,
	"real":             true,
	"double precision": true,

	// date / time
	"date":      true,
	"time":      true,
	"timestamp": true,

	// semi-structured and binary
	"json":  true,
	"jsonb": true,
	"bytea": true,
}

// Executor turns contract events (schema/table creation, row insert, update
// and delete) into SQL statements and runs them on its database handle.
type Executor struct {
	// db is the handle on which every statement built by this Executor runs.
	db *database.DB
}

// NewExecutor returns an Executor that runs its statements on db.
func NewExecutor(db *database.DB) *Executor {
	exec := new(Executor)
	exec.db = db
	return exec
}

// validateSqlType reports whether t is an acceptable SQL column type.
//
// t must consist of a base type listed in allowedSqlTypes (matched
// case-insensitively, surrounding whitespace ignored). It may optionally be
// followed by one parenthesised modifier list holding one or two unsigned
// integers, e.g. "VARCHAR(50)" or "numeric(10, 2)". Only whitespace may follow
// the closing parenthesis.
//
// The modifier is checked strictly because HandleTableCreated splices the
// type text verbatim into CREATE TABLE. Checking only the base type would let
// a suffix such as "varchar(1)); DROP SCHEMA public CASCADE; --" through.
func validateSqlType(t string) bool {
	open := strings.IndexByte(t, '(')
	if open < 0 {
		// No modifier: the whole string must be a whitelisted base type.
		return allowedSqlTypes[strings.ToLower(strings.TrimSpace(t))]
	}
	if !allowedSqlTypes[strings.ToLower(strings.TrimSpace(t[:open]))] {
		return false
	}

	rest := t[open+1:]
	closeIdx := strings.IndexByte(rest, ')')
	if closeIdx < 0 {
		return false // unterminated modifier
	}
	if strings.TrimSpace(rest[closeIdx+1:]) != "" {
		return false // trailing text after the modifier (e.g. DEFAULT ..., ;, [])
	}

	// Accept "(n)" or "(p, s)" where every part is a run of ASCII digits.
	parts := strings.Split(rest[:closeIdx], ",")
	if len(parts) > 2 {
		return false
	}
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p == "" {
			return false
		}
		for i := 0; i < len(p); i++ {
			if p[i] < '0' || p[i] > '9' {
				return false
			}
		}
	}
	return true
}

// EnsureSchema creates a schema named after contractAddr (quoted as an
// identifier) unless it already exists, so each contract gets its own
// namespace. It returns whatever error the database reports.
func (e *Executor) EnsureSchema(ctx context.Context, contractAddr string) error {
	stmt := "CREATE SCHEMA IF NOT EXISTS " + QuoteIdentifier(contractAddr)
	return e.db.Exec(ctx, stmt)
}

// HandleTableCreated handles a table-creation event. It proceeds in three steps:
//  1. It checks every column type with validateSqlType.
//  2. It ensures the contract's schema exists.
//  3. It runs CREATE TABLE IF NOT EXISTS "<contractAddr>"."<tableName>" (...).
//
// Columns flagged IsPrimaryKey are combined, in declaration order, into a
// single trailing PRIMARY KEY constraint. All validation happens before any
// statement is executed, so a rejected type leaves the database untouched.
func (e *Executor) HandleTableCreated(ctx context.Context, contractAddr string, tableName string, columns []struct {
	Name         string
	SqlType      string
	IsPrimaryKey bool
}) error {
	definitions := make([]string, 0, len(columns)+1)
	var keyColumns []string

	for _, c := range columns {
		// Security gate: only whitelisted types may reach the DDL below.
		if !validateSqlType(c.SqlType) {
			return fmt.Errorf("security error: invalid or forbidden sql type '%s' for column '%s'", c.SqlType, c.Name)
		}
		quotedName := QuoteIdentifier(c.Name)
		definitions = append(definitions, quotedName+" "+c.SqlType)
		if c.IsPrimaryKey {
			keyColumns = append(keyColumns, quotedName)
		}
	}
	if len(keyColumns) > 0 {
		definitions = append(definitions, "PRIMARY KEY ("+strings.Join(keyColumns, ", ")+")")
	}

	schema := QuoteIdentifier(contractAddr)

	// The schema must exist before a table can be created inside it.
	schemaDDL := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schema)
	log.Printf("Ensuring schema exists: %s", schemaDDL)
	if err := e.db.Exec(ctx, schemaDDL); err != nil {
		return fmt.Errorf("create schema failed: %w", err)
	}

	tableDDL := fmt.Sprintf(
		`CREATE TABLE IF NOT EXISTS %s.%s (%s)`,
		schema,
		QuoteIdentifier(tableName),
		strings.Join(definitions, ", "),
	)
	log.Printf("Executing DDL: %s", tableDDL)
	if err := e.db.Exec(ctx, tableDDL); err != nil {
		return fmt.Errorf("create table failed: %w", err)
	}
	return nil
}

// HandleDataInserted handles a single-row insert event, running
// INSERT INTO "<contractAddr>"."<tableName>" (cols...) VALUES ($1, ...).
//
// columns[i] receives values[i]. Column names are quoted with QuoteIdentifier
// and values are bound as parameters, so neither is spliced into the SQL text.
//
// It returns an error, without touching the database, when columns and values
// differ in length (the same check HandleDataUpdated performs) or when there
// are no columns at all. Either case would otherwise produce a statement
// PostgreSQL rejects.
func (e *Executor) HandleDataInserted(ctx context.Context, contractAddr string, tableName string, columns []string, values []string) error {
	if len(columns) != len(values) {
		return fmt.Errorf("columns and values length mismatch: %d vs %d", len(columns), len(values))
	}
	if len(columns) == 0 {
		return fmt.Errorf("insert into %s.%s has no columns", contractAddr, tableName)
	}

	schemaName := QuoteIdentifier(contractAddr)
	safeTableName := QuoteIdentifier(tableName)

	safeCols := make([]string, len(columns))
	placeholders := make([]string, len(columns))
	args := make([]interface{}, len(values))
	for i, c := range columns {
		safeCols[i] = QuoteIdentifier(c)
		placeholders[i] = fmt.Sprintf("$%d", i+1)
		args[i] = values[i]
	}

	sql := fmt.Sprintf(
		`INSERT INTO %s.%s (%s) VALUES (%s)`,
		schemaName,
		safeTableName,
		strings.Join(safeCols, ", "),
		strings.Join(placeholders, ", "),
	)

	return e.db.Exec(ctx, sql, args...)
}

// QuoteIdentifier returns s as a double-quoted SQL identifier. Every embedded
// double quote is doubled, so the result is always a single identifier token.
// This guards against injection through names and lets reserved words or
// mixed-case names be used as-is.
func QuoteIdentifier(s string) string {
	escaped := strings.ReplaceAll(s, `"`, `""`)
	return fmt.Sprint(`"`, escaped, `"`)
}

// HandleDataUpdated handles an update event by running
// UPDATE "<contractAddr>"."<tableName>" SET col = $1, ... WHERE <cond>.
//
// setColumns[i] is assigned setValues[i]; the SET values are bound as
// $1..$k, where k = len(setValues). whereClause is parsed by the WHERE parser.
// The parser is limited to setColumns plus a fixed list of common column names.
// Its placeholders, numbered from $1, are then shifted to start at $k+1 so
// that the SET and WHERE arguments line up.
//
// Errors are returned for:
//   - setColumns and setValues of different lengths
//   - an empty SET list
//   - a WHERE clause that fails to parse
//   - a WHERE clause that parses to an empty condition. An unconditional
//     UPDATE is refused for safety, the same way HandleDataDeleted refuses
//     an unconditional DELETE.
func (e *Executor) HandleDataUpdated(ctx context.Context, contractAddr string, tableName string, setColumns []string, setValues []string, whereClause string) error {
	if len(setColumns) != len(setValues) {
		return fmt.Errorf("columns and values length mismatch: %d vs %d", len(setColumns), len(setValues))
	}
	if len(setColumns) == 0 {
		return fmt.Errorf("UPDATE without SET columns is not allowed")
	}

	schemaName := QuoteIdentifier(contractAddr)
	safeTableName := QuoteIdentifier(tableName)

	// Build the SET clause with parameters $1..$k.
	setParts := make([]string, len(setColumns))
	args := make([]interface{}, 0, len(setValues))
	for i, col := range setColumns {
		setParts[i] = fmt.Sprintf("%s = $%d", QuoteIdentifier(col), i+1)
		args = append(args, setValues[i])
	}

	// Columns the WHERE parser may reference: the SET columns plus common names
	// (same fallback list as HandleDataDeleted). Names already present are
	// skipped case-insensitively.
	// TODO: query the real table columns (information_schema) instead.
	allowedCols := append([]string(nil), setColumns...)
	commonCols := []string{"id", "name", "status", "email", "price", "stock", "created_at", "updated_at", "type", "category"}
	for _, col := range commonCols {
		found := false
		for _, existing := range allowedCols {
			if strings.EqualFold(existing, col) {
				found = true
				break
			}
		}
		if !found {
			allowedCols = append(allowedCols, col)
		}
	}

	parser := NewWhereParser(allowedCols)
	whereCond, err := parser.Parse(whereClause)
	if err != nil {
		return fmt.Errorf("parse WHERE clause failed: %w", err)
	}
	if whereCond.SQL == "" {
		return fmt.Errorf("UPDATE without WHERE clause is not allowed for safety")
	}

	// The parser numbers its placeholders from $1; move them past the SET params.
	adjustedWhereSQL := shiftWherePlaceholders(whereCond.SQL, len(setValues), len(whereCond.Args))
	args = append(args, whereCond.Args...)

	sql := fmt.Sprintf(
		`UPDATE %s.%s SET %s WHERE %s`,
		schemaName,
		safeTableName,
		strings.Join(setParts, ", "),
		adjustedWhereSQL,
	)

	log.Printf("Executing UPDATE: %s", sql)
	return e.db.Exec(ctx, sql, args...)
}

// shiftWherePlaceholders rewrites each "$N" token in sql with 1 <= N <= count
// to "$(N+offset)". It does this in a single left-to-right scan.
//
// Repeated strings.ReplaceAll calls cannot do this correctly, in any order,
// because a freshly written "$10" still contains the text "$1". Tokens outside
// 1..count and lone '$' characters are copied unchanged.
func shiftWherePlaceholders(sql string, offset, count int) string {
	var b strings.Builder
	b.Grow(len(sql))
	for i := 0; i < len(sql); {
		if sql[i] != '$' {
			b.WriteByte(sql[i])
			i++
			continue
		}
		j := i + 1
		num := 0
		for j < len(sql) && sql[j] >= '0' && sql[j] <= '9' {
			// Once num exceeds count it can only grow, so stop accumulating
			// (this also prevents overflow on absurdly long digit runs).
			if num <= count {
				num = num*10 + int(sql[j]-'0')
			}
			j++
		}
		if j > i+1 && num >= 1 && num <= count {
			fmt.Fprintf(&b, "$%d", num+offset)
		} else {
			b.WriteString(sql[i:j])
		}
		i = j
	}
	return b.String()
}

// HandleDataDeleted handles a row-deletion event by running
// DELETE FROM "<contractAddr>"."<tableName>" WHERE <cond>.
//
// whereClause is parsed by a WHERE parser built with the column list below.
// The parsed argument values are passed to the database as bind arguments.
// If the parsed condition is empty, the delete is refused so that a whole
// table is never wiped.
func (e *Executor) HandleDataDeleted(ctx context.Context, contractAddr string, tableName string, whereClause string) error {
	// TODO: derive the permitted columns from information_schema instead of
	// this hard-coded list of common names.
	permitted := []string{"id", "name", "status", "email", "price", "stock", "created_at", "updated_at", "type", "category"}
	wp := NewWhereParser(permitted)

	cond, err := wp.Parse(whereClause)
	switch {
	case err != nil:
		return fmt.Errorf("parse WHERE clause failed: %w", err)
	case cond.SQL == "":
		return fmt.Errorf("DELETE without WHERE clause is not allowed for safety")
	}

	stmt := fmt.Sprintf(
		`DELETE FROM %s.%s WHERE %s`,
		QuoteIdentifier(contractAddr),
		QuoteIdentifier(tableName),
		cond.SQL,
	)
	log.Printf("Executing DELETE: %s", stmt)
	return e.db.Exec(ctx, stmt, cond.Args...)
}

// HandleBatchInserted 处理批量插入事件
func (e *Executor) HandleBatchInserted(ctx context.Context, contractAddr string, tableName string, columns []string, rows [][]string) error {
	if len(rows) == 0 {
		return nil
	}

	schemaName := QuoteIdentifier(contractAddr)
	safeTableName := QuoteIdentifier(tableName)
	safeCols := make([]string, len(columns))

	for i, c := range columns {
		safeCols[i] = QuoteIdentifier(c)
	}

	// 构建批量插入 SQL
	var valueParts []string
	args := make([]interface{}, 0, len(rows)*len(columns))
	paramIndex := 1

	for _, row := range rows {
		if len(row) != len(columns) {
			return fmt.Errorf("row length mismatch: expected %d, got %d", len(columns), len(row))
		}

		placeholders := make([]string, len(row))
		for i, val := range row {
			placeholders[i] = fmt.Sprintf("$%d", paramIndex)
			args = append(args, val)
			paramIndex++
		}
		valueParts = append(valueParts, fmt.Sprintf("(%s)", strings.Join(placeholders, ", ")))
	}

	sql := fmt.Sprintf(
		`INSERT INTO %s.%s (%s) VALUES %s`,
		schemaName,
		safeTableName,
		strings.Join(safeCols, ", "),
		strings.Join(valueParts, ", "),
	)

	log.Printf("Executing BATCH INSERT: %d rows into %s.%s", len(rows), contractAddr, tableName)
	return e.db.Exec(ctx, sql, args...)
}
