package executor

import (
	"fmt"
	"regexp"
	"strings"
)

// WhereCondition is the result of parsing a WHERE clause: a SQL fragment
// that references its values only through numbered $N placeholders, plus the
// values to bind to those placeholders in the order they were generated.
// The fragment does not include the WHERE keyword itself.
type WhereCondition struct {
	// SQL is the parameterized condition text.
	SQL string
	// Args are the bind values for the placeholders in SQL.
	Args []interface{}
}

// WhereParser converts WHERE clause text into parameterized SQL and rejects
// any condition whose column is not on its whitelist.
type WhereParser struct {
	// allowedColumns is the column whitelist, keyed by lower-cased name.
	allowedColumns map[string]bool
	// paramCounter is the number of the next $N placeholder to be issued.
	paramCounter int
}

// NewWhereParser returns a parser that only accepts conditions on the given
// columns. Each name is trimmed and lower-cased before being stored. This is
// the same normalisation validateColumn applies to incoming column names, so
// an entry such as " ID " still matches "id". Blank names are ignored.
// Placeholder numbering starts at $1.
func NewWhereParser(allowedColumns []string) *WhereParser {
	colMap := make(map[string]bool, len(allowedColumns))
	for _, col := range allowedColumns {
		key := strings.ToLower(strings.TrimSpace(col))
		if key == "" {
			continue
		}
		colMap[key] = true
	}
	return &WhereParser{
		allowedColumns: colMap,
		paramCounter:   1,
	}
}

// allowedOperators is the set of operators a condition may use, keyed by
// upper-case spelling. parseCondition handles the IN and NULL forms with
// dedicated patterns; every other entry is treated as a binary comparison
// "column OP value".
var allowedOperators = map[string]bool{
	// Equality and ordering.
	"=": true, "!=": true, "<>": true,
	">": true, "<": true, ">=": true, "<=": true,

	// Pattern matching.
	"LIKE": true, "ILIKE": true,

	// Set membership.
	"IN": true, "NOT IN": true,

	// NULL checks.
	"IS NULL": true, "IS NOT NULL": true,
}

// Parse converts a WHERE clause into a parameterized SQL fragment and its
// bind values.
//
// Input examples:
//
//	"id = 1"
//	"name = 'John' AND age > 18"
//	"status IN ('active', 'pending')"
//	"email IS NOT NULL"
//
// The clause is split by tokenize into conditions, AND/OR keywords and
// grouping parentheses. Each condition is converted by parseCondition, and
// AND/OR are emitted in upper case.
//
// Before anything is returned, the token sequence is checked structurally:
//   - parentheses must balance;
//   - no group may be empty;
//   - conditions and AND/OR must alternate.
//
// This check is a safety property, not only a correctness one. An unmatched
// ")" in the input could otherwise close a parenthesis belonging to the query
// the fragment is embedded in.
//
// Placeholders are numbered from p.paramCounter, which keeps advancing across
// successful calls. If Parse returns an error, the counter is restored to its
// value on entry so no placeholder numbers are consumed.
//
// A blank clause yields an empty SQL string and an empty, non-nil Args.
func (p *WhereParser) Parse(whereClause string) (*WhereCondition, error) {
	if strings.TrimSpace(whereClause) == "" {
		return &WhereCondition{SQL: "", Args: []interface{}{}}, nil
	}

	startCounter := p.paramCounter
	fail := func(err error) (*WhereCondition, error) {
		p.paramCounter = startCounter
		return nil, err
	}

	var sqlParts []string
	var args []interface{}
	depth := 0            // currently open grouping parentheses
	expectOperand := true // next token must be a condition or "("

	for _, token := range p.tokenize(whereClause) {
		token = strings.TrimSpace(token)
		if token == "" {
			continue
		}

		switch upper := strings.ToUpper(token); {
		case upper == "AND" || upper == "OR":
			if expectOperand {
				return fail(fmt.Errorf("unexpected %s: expected a condition or '('", upper))
			}
			sqlParts = append(sqlParts, upper)
			expectOperand = true

		case token == "(":
			if !expectOperand {
				return fail(fmt.Errorf("unexpected '(' after a condition: missing AND/OR"))
			}
			depth++
			sqlParts = append(sqlParts, token)

		case token == ")":
			if expectOperand {
				// Either "()" or an operator directly before ")".
				return fail(fmt.Errorf("unexpected ')': expected a condition"))
			}
			if depth == 0 {
				return fail(fmt.Errorf("unbalanced parentheses: unmatched ')'"))
			}
			depth--
			sqlParts = append(sqlParts, token)

		default:
			if !expectOperand {
				return fail(fmt.Errorf("unexpected condition '%s': missing AND/OR before it", token))
			}
			condSQL, condArgs, err := p.parseCondition(token)
			if err != nil {
				return fail(fmt.Errorf("parse condition '%s' failed: %w", token, err))
			}
			sqlParts = append(sqlParts, condSQL)
			args = append(args, condArgs...)
			expectOperand = false
		}
	}

	if depth != 0 {
		return fail(fmt.Errorf("unbalanced parentheses: %d unclosed '('", depth))
	}
	if expectOperand {
		return fail(fmt.Errorf("incomplete WHERE clause: expected a condition after the last AND/OR"))
	}

	return &WhereCondition{
		SQL:  strings.Join(sqlParts, " "),
		Args: args,
	}, nil
}

// tokenize splits a WHERE clause into the flat token stream consumed by
// Parse. Each returned token is one of:
//   - "AND" or "OR", in the casing used in the input;
//   - a grouping parenthesis, "(" or ")";
//   - the trimmed text of a single condition, e.g. "age > 18" or
//     "status IN ('a', 'b')".
//
// Quoted literals ('...' or "...") are opaque: AND/OR and parentheses inside
// them are not treated as syntax. A quote preceded by a backslash neither
// opens nor closes a literal.
//
// A "(" that follows the keyword IN (including NOT IN) starts a value list.
// That list, up to its matching ")", is kept inside the condition text so
// parseCondition sees the whole "col IN (...)" expression.
//
// AND/OR split conditions only when they appear as standalone words, so
// identifiers such as "order_id" or "brand" are left intact.
func (p *WhereParser) tokenize(input string) []string {
	var tokens []string
	start := 0     // byte offset where the pending condition text begins
	listDepth := 0 // nesting depth inside an IN (...) value list

	// flush emits input[start:end] as a condition token unless it is blank.
	flush := func(end int) {
		if text := strings.TrimSpace(input[start:end]); text != "" {
			tokens = append(tokens, text)
		}
	}

	// All syntax characters are ASCII, and UTF-8 never reuses ASCII byte
	// values inside multi-byte runes, so scanning bytes is safe.
	for i := 0; i < len(input); {
		ch := input[i]
		switch {
		case (ch == '\'' || ch == '"') && (i == 0 || input[i-1] != '\\'):
			// Quoted literal: it stays verbatim in the pending condition.
			i = whereQuotedEnd(input, i)

		case ch == '(':
			if listDepth > 0 || whereEndsWithIn(input[start:i]) {
				listDepth++ // part of an IN list, kept in the condition text
			} else {
				flush(i)
				tokens = append(tokens, "(")
				start = i + 1
			}
			i++

		case ch == ')':
			if listDepth > 0 {
				listDepth--
			} else {
				flush(i)
				tokens = append(tokens, ")")
				start = i + 1
			}
			i++

		default:
			if listDepth == 0 {
				if n := whereLogicalKeywordLen(input, i); n > 0 {
					flush(i)
					tokens = append(tokens, input[i:i+n])
					start = i + n
					i = start
					continue
				}
			}
			i++
		}
	}
	flush(len(input))

	return tokens
}

// whereIsSpace reports whether b is an ASCII whitespace byte matched by the
// regexp class \s (space, \t, \n, \f, \r).
func whereIsSpace(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\f', '\r':
		return true
	}
	return false
}

// whereIsDelimiter reports whether b may border a standalone AND/OR keyword.
func whereIsDelimiter(b byte) bool {
	return whereIsSpace(b) || b == '(' || b == ')'
}

// whereQuotedEnd returns the index just past the quoted literal that opens at
// s[open], or len(s) if the literal is never closed. A quote preceded by a
// backslash does not close the literal.
func whereQuotedEnd(s string, open int) int {
	quote := s[open]
	for j := open + 1; j < len(s); j++ {
		if s[j] == quote && s[j-1] != '\\' {
			return j + 1
		}
	}
	return len(s)
}

// whereEndsWithIn reports whether text, ignoring trailing whitespace, ends
// with the standalone keyword IN (case-insensitive).
func whereEndsWithIn(text string) bool {
	end := len(text)
	for end > 0 && whereIsSpace(text[end-1]) {
		end--
	}
	if end < 2 || !strings.EqualFold(text[end-2:end], "IN") {
		return false
	}
	return end == 2 || whereIsSpace(text[end-3])
}

// whereLogicalKeywordLen returns the length of a standalone AND/OR keyword
// (case-insensitive) starting at s[i], or 0 if there is none. The keyword
// must be bounded on both sides by whitespace, a parenthesis, or the ends of
// the string.
func whereLogicalKeywordLen(s string, i int) int {
	if i > 0 && !whereIsDelimiter(s[i-1]) {
		return 0
	}
	for _, kw := range [...]string{"AND", "OR"} {
		end := i + len(kw)
		if end <= len(s) && strings.EqualFold(s[i:end], kw) &&
			(end == len(s) || whereIsDelimiter(s[end])) {
			return len(kw)
		}
	}
	return 0
}

// Patterns used by parseCondition. They are compiled once at package
// initialisation rather than on every call.
var (
	// whereIsNotNullSuffix matches a trailing "IS NOT NULL" in any letter case.
	whereIsNotNullSuffix = regexp.MustCompile(`(?i)\s+IS\s+NOT\s+NULL\s*$`)
	// whereIsNullSuffix matches a trailing "IS NULL" in any letter case.
	whereIsNullSuffix = regexp.MustCompile(`(?i)\s+IS\s+NULL\s*$`)
	// whereInList matches "col IN (...)" and "col NOT IN (...)". Its groups
	// are: 1 = column, 2 = non-empty for NOT IN, 3 = the list body.
	whereInList = regexp.MustCompile(`(?i)^(\w+)\s+(NOT\s+)?IN\s*\((.+)\)$`)
	// whereComparison matches "col OP value", where OP is the second
	// whitespace-separated field. OP is checked against allowedOperators, so
	// only single-word operators can be matched here.
	whereComparison = regexp.MustCompile(`^(\w+)\s+(\S+)\s+(.+)$`)
)

// parseCondition converts one condition (containing no AND/OR) into a
// parameterized SQL fragment and its bind values.
//
// Accepted forms; keywords and operators are case-insensitive:
//
//	col IS NULL, col IS NOT NULL   no bind values
//	col IN (v1, v2, ...)           one placeholder per value
//	col NOT IN (v1, v2, ...)       one placeholder per value
//	col OP value                   OP is another entry of allowedOperators
//
// Every column must pass validateColumn and is emitted through
// QuoteIdentifier. Values are unquoted with unquoteValue, either directly or
// via parseInValues. Each placeholder takes the current p.paramCounter, which
// is then incremented.
//
// An error is returned for a rejected column, an empty IN list, or any
// other shape.
func (p *WhereParser) parseCondition(condition string) (string, []interface{}, error) {
	condition = strings.TrimSpace(condition)

	// nullCheck validates the text left of an IS [NOT] NULL suffix.
	nullCheck := func(column, keyword string) (string, []interface{}, error) {
		column = strings.TrimSpace(column)
		if err := p.validateColumn(column); err != nil {
			return "", nil, err
		}
		return fmt.Sprintf("%s %s", QuoteIdentifier(column), keyword), []interface{}{}, nil
	}

	if loc := whereIsNotNullSuffix.FindStringIndex(condition); loc != nil {
		return nullCheck(condition[:loc[0]], "IS NOT NULL")
	}
	if loc := whereIsNullSuffix.FindStringIndex(condition); loc != nil {
		return nullCheck(condition[:loc[0]], "IS NULL")
	}

	if m := whereInList.FindStringSubmatch(condition); m != nil {
		column := m[1]
		notIn := strings.TrimSpace(m[2]) != ""

		if err := p.validateColumn(column); err != nil {
			return "", nil, err
		}

		values := p.parseInValues(m[3])
		if len(values) == 0 {
			return "", nil, fmt.Errorf("empty IN clause")
		}

		placeholders := make([]string, len(values))
		args := make([]interface{}, len(values))
		for i, v := range values {
			placeholders[i] = fmt.Sprintf("$%d", p.paramCounter)
			args[i] = v
			p.paramCounter++
		}

		operator := "IN"
		if notIn {
			operator = "NOT IN"
		}

		sql := fmt.Sprintf("%s %s (%s)", QuoteIdentifier(column), operator, strings.Join(placeholders, ", "))
		return sql, args, nil
	}

	if m := whereComparison.FindStringSubmatch(condition); m != nil {
		column, op := m[1], strings.ToUpper(m[2])
		// IN is valid only in its parenthesised form, handled above. The
		// multi-word IN/NULL operators can never match the single \S+ field.
		if op != "IN" && allowedOperators[op] {
			if err := p.validateColumn(column); err != nil {
				return "", nil, err
			}

			value := p.unquoteValue(strings.TrimSpace(m[3]))

			sql := fmt.Sprintf("%s %s $%d", QuoteIdentifier(column), op, p.paramCounter)
			p.paramCounter++
			return sql, []interface{}{value}, nil
		}
	}

	return "", nil, fmt.Errorf("invalid condition format: %s", condition)
}

// validateColumn checks a column name against the parser's whitelist.
// The name is trimmed and lower-cased before the lookup. It returns nil when
// the column is allowed. Otherwise it returns an error that quotes the
// normalised name.
func (p *WhereParser) validateColumn(column string) error {
	normalized := strings.ToLower(strings.TrimSpace(column))
	if allowed := p.allowedColumns[normalized]; allowed {
		return nil
	}
	return fmt.Errorf("column '%s' not allowed in WHERE clause", normalized)
}

// parseInValues splits the body of an IN (...) list into its values.
//
// Values are separated by commas that are outside quoted sections. A quote
// preceded by a backslash neither opens nor closes a quoted section. Each
// value is trimmed and passed through unquoteValue.
//
// Entries that are blank after trimming are skipped. This covers "1,,2",
// "1, , 2" and a trailing comma, so none of them produces a spurious empty
// bind value. An explicitly quoted empty string ('') is kept as "".
func (p *WhereParser) parseInValues(valuesPart string) []string {
	var values []string
	var current strings.Builder
	inQuote := false
	quoteChar := rune(0)

	// emit flushes the buffered entry, dropping it if it is blank.
	emit := func() {
		raw := strings.TrimSpace(current.String())
		current.Reset()
		if raw != "" {
			values = append(values, p.unquoteValue(raw))
		}
	}

	for i, ch := range valuesPart {
		switch {
		case (ch == '\'' || ch == '"') && (i == 0 || valuesPart[i-1] != '\\'):
			if !inQuote {
				inQuote = true
				quoteChar = ch
			} else if ch == quoteChar {
				inQuote = false
				quoteChar = 0
			}
			current.WriteRune(ch)

		case ch == ',' && !inQuote:
			emit()

		default:
			current.WriteRune(ch)
		}
	}
	emit()

	return values
}

// unquoteValue trims surrounding whitespace from value. If the result is at
// least two bytes long and starts and ends with the same quote character
// (' or "), those outer quotes are removed. Any other input is returned
// trimmed but otherwise unchanged.
//
// NOTE(review): escape sequences inside the literal (e.g. \' or '') are left
// as-is, even though the tokenizers treat \' as an escaped quote. Confirm
// whether callers expect them to be unescaped.
func (p *WhereParser) unquoteValue(value string) string {
	trimmed := strings.TrimSpace(value)
	if len(trimmed) < 2 {
		return trimmed
	}

	first, last := trimmed[0], trimmed[len(trimmed)-1]
	if first == last && (first == '\'' || first == '"') {
		return trimmed[1 : len(trimmed)-1]
	}
	return trimmed
}
