package main

import (
	"bytes"
	"database/sql"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"bufio"
	"fmt"
	"net/url"
	"path/filepath"
	"regexp"
	"sync"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/common/hexutil"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/ethclient"
	"github.com/fsnotify/fsnotify"
	_ "github.com/go-sql-driver/mysql"
)

// TbAccountInfo is the Go representation of a row in the tb_account_info
// table (the name returned by TableName). Its JSON tags use the same
// snake_case names as the columns that accountExists queries
// (account_address, is_deleted).
//
// Field notes (NOTE(review): inferred from names only, confirm against the schema):
//   - BlockId, BlockHash, TxHash: presumably the block/transaction in which the
//     account was first recorded.
//   - AccountType, Status: integer codes whose meanings are not defined here.
//   - MyNameTag: presumably a user-assigned label for the address.
//   - Balance: stored as float64, which cannot represent large integer
//     (e.g. wei-denominated) amounts exactly; confirm the unit.
//   - IsDeleted: soft-delete flag; accountExists only counts rows with
//     is_deleted = 0.
//   - SyncTime, CreateTime, UpdateTime: bookkeeping timestamps.
//
// NOTE(review): the struct is not referenced by the code in this file;
// accountExists reads the table with raw SQL instead.
type TbAccountInfo struct {
	Id             int64     `json:"id"`
	BlockId        int64     `json:"block_id"`
	BlockHash      string    `json:"block_hash"`
	TxHash         string    `json:"tx_hash"`
	AccountAddress string    `json:"account_address"`
	AccountType    int       `json:"account_type"`
	MyNameTag      string    `json:"my_name_tag"`
	Balance        float64   `json:"balance"`
	Status         int       `json:"status"`
	IsDeleted      int8      `json:"is_deleted"`
	SyncTime       time.Time `json:"sync_time"`
	CreateTime     time.Time `json:"create_time"`
	UpdateTime     time.Time `json:"update_time"`
}

// TableName returns the name of the database table backing TbAccountInfo.
// The method has the shape ORMs such as GORM look for (its Tabler
// interface); presumably that is its consumer — confirm.
func (t *TbAccountInfo) TableName() string {
	return "tb_account_info"
}

// RPCRequest is a single JSON-RPC request as decoded by proxyHandler.
//
// Only Method and Params drive the proxy's decisions (eth_getBalance,
// eth_sendRawTransaction, and any call whose params carry a "from" field).
// Jsonrpc and Id are only echoed back in locally generated responses. The
// raw request body, not this struct, is what gets forwarded to the backend.
//
// Params must be a JSON array. A by-name (object) params value fails to
// decode into this struct, in which case proxyHandler forwards the request
// unchecked.
type RPCRequest struct {
	Jsonrpc string        `json:"jsonrpc"`
	Method  string        `json:"method"`
	Params  []interface{} `json:"params"`
	Id      interface{}   `json:"id"`
}

// RPCResponse is a JSON-RPC response generated locally by the proxy.
// Responses relayed from the backend are copied verbatim and do not pass
// through this type.
//
// Result is tagged omitempty so that error responses carry only "error".
// JSON-RPC 2.0 requires that "result" not exist when there was an error.
// The code in this file never builds a successful response whose result is
// null, so omitting a nil Result loses nothing.
type RPCResponse struct {
	Jsonrpc string      `json:"jsonrpc"`
	Id      interface{} `json:"id"`
	Result  interface{} `json:"result,omitempty"`
	Error   interface{} `json:"error,omitempty"`
}

// Process-wide state shared by main, the HTTP handler and background workers.
var (
	// db is the MySQL handle opened in main and queried by accountExists.
	db *sql.DB

	// rpcBackend is the upstream Ethereum JSON-RPC URL. It is taken from the
	// ETH_RPC_BACKEND environment variable when the package is initialised.
	// An empty value disables forwarding (forwardToBackend answers 503) and
	// blacklist checks.
	rpcBackend = os.Getenv("ETH_RPC_BACKEND")

	// ethClient is dialled against rpcBackend in main; nil when it is unset.
	ethClient *ethclient.Client

	// blacklistContract is read from BLACKLIST_CONTRACT_ADDR in main. The
	// zero address means blacklist checks are skipped.
	blacklistContract common.Address
)

// Whitelist configuration. whitelist and whitelistPatterns are replaced
// wholesale on reload; readers take whitelistMu for reading.
var (
	whitelistFile     string
	whitelist         map[string]struct{}
	whitelistPatterns []*regexp.Regexp
	whitelistMu       sync.RWMutex
)

// main wires up the proxy from environment variables and serves it on :8545.
//
// Environment:
//
//	MYSQL_DSN                         DSN for the account table (go-sql-driver/mysql format)
//	ETH_RPC_BACKEND                   upstream JSON-RPC URL (read at package init)
//	BLACKLIST_CONTRACT_ADDR           blacklist contract address; unset disables checks
//	WHITELIST_FILE                    optional whitelist file, reloaded when it changes
//	BLACKLIST_CACHE_CLEANUP_INTERVAL  blacklist cache janitor interval (default 5m)
func main() {
	// Initialize database connection. sql.Open may only validate its
	// arguments; connection problems surface on the first query, where
	// accountExists fails open.
	var err error
	dsn := os.Getenv("MYSQL_DSN") // Example: "user:password@tcp(127.0.0.1:3306)/dbname"
	db, err = sql.Open("mysql", dsn)
	if err != nil {
		log.Fatalf("Database connection failed: %v", err)
	}
	// NOTE: log.Fatal* below exits via os.Exit, which skips this defer.
	defer func() {
		if err := db.Close(); err != nil {
			log.Printf("db close error: %v", err)
		}
	}()

	// Initialize eth client for blacklist checks if backend provided
	if rpcBackend != "" {
		client, err := ethclient.Dial(rpcBackend)
		if err != nil {
			log.Fatalf("failed to create eth client: %v", err)
		}
		ethClient = client
	} else {
		log.Printf("ETH_RPC_BACKEND not set, blacklist checks will be disabled")
	}

	// Setup blacklist contract address from env if provided
	if addr := os.Getenv("BLACKLIST_CONTRACT_ADDR"); addr != "" {
		blacklistContract = common.HexToAddress(addr)
	} else {
		// leave zero address; checks will be skipped
		log.Printf("BLACKLIST_CONTRACT_ADDR not set, blacklist checks will be disabled")
	}

	// Load whitelist file and start watcher
	whitelistFile = os.Getenv("WHITELIST_FILE")
	if whitelistFile != "" {
		wlExact, wlPatterns := loadWhitelist(whitelistFile)
		whitelistMu.Lock()
		whitelist = wlExact
		whitelistPatterns = wlPatterns
		whitelistMu.Unlock()
		log.Printf("loaded whitelist entries: %d exact, %d patterns", len(wlExact), len(wlPatterns))
		startWhitelistWatcher(whitelistFile)
	} else {
		whitelist = map[string]struct{}{}
		whitelistPatterns = []*regexp.Regexp{}
		log.Printf("WHITELIST_FILE not set, whitelist feature disabled")
	}

	// Start cache janitor for blacklist cache. Interval can be configured via env BLACKLIST_CACHE_CLEANUP_INTERVAL (e.g. "5m").
	cleanupInterval := 5 * time.Minute
	if s := os.Getenv("BLACKLIST_CACHE_CLEANUP_INTERVAL"); s != "" {
		// Values such as "0s" or "-1m" parse fine but are not usable
		// periods (time.NewTicker, for one, panics on them), so they are
		// rejected like unparsable input.
		if d, err := time.ParseDuration(s); err == nil && d > 0 {
			cleanupInterval = d
		} else {
			log.Printf("invalid BLACKLIST_CACHE_CLEANUP_INTERVAL '%s', using default %s", s, cleanupInterval)
		}
	}
	startBlacklistCacheJanitor(cleanupInterval)

	http.HandleFunc("/", proxyHandler)

	// ReadHeaderTimeout bounds how long a client may take to send its request
	// headers, so slow-header (Slowloris) connections cannot hold goroutines
	// indefinitely. ReadTimeout/WriteTimeout are deliberately left unset
	// because some JSON-RPC calls could legitimately run long. Handler is
	// nil, so the DefaultServeMux populated above is used.
	server := &http.Server{
		Addr:              ":8545",
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Println("RPC proxy service started, listening on port: 8545")
	log.Fatal(server.ListenAndServe())
}

// proxyHandler is the single HTTP handler of the proxy. For every request it:
//
//  1. writes CORS headers and answers OPTIONS preflights with 200;
//  2. forwards the body untouched when the Origin or Referer header matches
//     the whitelist;
//  3. answers every JSON-RPC batch (any body that decodes as []RPCRequest)
//     with an "invalid batch request" result, so batches are never forwarded
//     and cannot bypass the checks below;
//  4. answers eth_getBalance locally with "0x0" when the address is not found
//     by accountExists;
//  5. when ethClient and blacklistContract are configured, rejects with a
//     -32000 JSON-RPC error an eth_sendRawTransaction whose recovered sender
//     is blacklisted, and any call whose params carry a blacklisted "from";
//  6. forwards everything else to the backend.
//
// When a check cannot be completed (body is not a JSON-RPC object, raw
// transaction cannot be decoded, blacklist lookup fails) the request is
// forwarded ("fail open").
//
// NOTE(review): Origin and Referer are client-controlled, so any non-browser
// client can trigger the whitelist short-circuit by setting them.
func proxyHandler(w http.ResponseWriter, r *http.Request) {
	// Add CORS support
	setCORSHeaders(w, r)

	// Handle preflight requests
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Failed to read request", http.StatusBadRequest)
		return
	}
	defer func() {
		if err := r.Body.Close(); err != nil {
			log.Printf("request body close error: %v", err)
		}
	}()

	// Whitelist short-circuit: if Origin or Referer matches whitelist, forward immediately
	origin := r.Header.Get("Origin")
	referer := r.Header.Get("Referer")
	if isWhitelisted(origin) || isWhitelisted(referer) {
		log.Printf("whitelist matched (origin=%s referer=%s), forwarding directly", origin, referer)
		forwardToBackend(w, body)
		return
	}

	// Batch requests are never forwarded. Every batch, including empty and
	// single-element ones, gets an explicit reply instead of an empty body.
	// The jsonrpc/id of the first element (if any) are echoed back.
	var reqs []RPCRequest
	if err := json.Unmarshal(body, &reqs); err == nil {
		var first RPCRequest
		if len(reqs) > 0 {
			first = reqs[0]
		}
		resp := RPCResponse{
			Jsonrpc: first.Jsonrpc,
			Id:      first.Id,
			Result:  "invalid batch request",
		}
		xForwardedFor := r.Header.Get("X-Forwarded-For")
		log.Printf("stop forward to rpc on batch request, request from X-Forwarded-For: %s, param[0]: %v", xForwardedFor, first.Params)
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Printf("encode response error: %v", err)
		}
		return
	}

	var req RPCRequest
	if err := json.Unmarshal(body, &req); err != nil {
		// Not a JSON-RPC object we can inspect; let the backend answer it.
		forwardToBackend(w, body)
		return
	}

	// eth_getBalance: addresses unknown to tb_account_info are answered with
	// "0x0" without contacting the backend.
	if req.Method == "eth_getBalance" && len(req.Params) > 0 {
		// get remote ip from header
		xForwardedFor := r.Header.Get("X-Forwarded-For")
		realIp := r.Header.Get("X-Real-IP")
		log.Printf("eth_getBalance request from %s, X-Real-IP: %s, X-Forwarded-For: %s, address: %v", r.RemoteAddr, realIp, xForwardedFor, req.Params[0])
		address, ok := req.Params[0].(string)
		if !ok {
			forwardToBackend(w, body)
			return
		}
		if !accountExists(strings.ToLower(address)) {
			resp := RPCResponse{
				Jsonrpc: req.Jsonrpc,
				Id:      req.Id,
				Result:  "0x0",
			}
			log.Printf("stop forward to rpc on eth_getBalance request from %s, X-Real-IP: %s, X-Forwarded-For: %s, address: %v", r.RemoteAddr, realIp, xForwardedFor, req.Params[0])
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(resp); err != nil {
				log.Printf("encode response error: %v", err)
			}
			return
		}
	}

	// eth_sendRawTransaction carries no "from" field, so the sender is
	// recovered from the signed transaction itself.
	if req.Method == "eth_sendRawTransaction" && len(req.Params) > 0 {
		if rawHex, ok := req.Params[0].(string); ok {
			if fromAddr, err := getSenderFromRawTx(rawHex); err == nil {
				// only perform blacklist check if eth client and blacklist contract are configured
				if ethClient != nil && blacklistContract != (common.Address{}) {
					inBlack, err := CachedIsInBlacklist(ethClient, blacklistContract, common.HexToAddress(strings.ToLower(fromAddr)))
					if err != nil {
						log.Printf("blacklist check failed for %s: %v", fromAddr, err)
						// fail open: forward
						forwardToBackend(w, body)
						return
					}
					if inBlack {
						errResp := RPCResponse{
							Jsonrpc: req.Jsonrpc,
							Id:      req.Id,
							Error: map[string]interface{}{
								"code":    -32000,
								"message": "sender is invalid",
							},
						}
						w.Header().Set("Content-Type", "application/json")
						if err := json.NewEncoder(w).Encode(errResp); err != nil {
							log.Printf("encode error: %v", err)
						}
						return
					}
				}
			} else {
				// If we couldn't decode the sender, just forward (fail open)
				log.Printf("failed to decode raw tx sender: %v", err)
				forwardToBackend(w, body)
				return
			}
		}
	}

	// Any request whose params include an object with a `from` address
	// (e.g. eth_sendTransaction, eth_call, eth_estimateGas) is checked too.
	if from, ok := extractFromAddress(req); ok {
		// only perform blacklist check if eth client and blacklist contract are configured
		if ethClient != nil && blacklistContract != (common.Address{}) {
			inBlack, err := CachedIsInBlacklist(ethClient, blacklistContract, common.HexToAddress(strings.ToLower(from)))
			if err != nil {
				// on error, log and forward to backend (fail open)
				log.Printf("blacklist check failed for %s: %v", from, err)
				forwardToBackend(w, body)
				return
			}
			if inBlack {
				// return JSON-RPC error response indicating sender is blacklisted
				errResp := RPCResponse{
					Jsonrpc: req.Jsonrpc,
					Id:      req.Id,
					Error: map[string]interface{}{
						"code":    -32000,
						"message": "sender is blacklisted",
					},
				}
				w.Header().Set("Content-Type", "application/json")
				if err := json.NewEncoder(w).Encode(errResp); err != nil {
					log.Printf("encode error: %v", err)
				}
				return
			}
		}
	}

	// Forward other cases directly
	forwardToBackend(w, body)
}

// getSenderFromRawTx recovers the signer of a raw, signed transaction.
//
// rawHex must be "0x"-prefixed hex of the transaction's binary encoding as
// accepted by types.Transaction.UnmarshalBinary. The prefix is mandatory:
// hexutil.Decode rejects input without it. The signer is chosen from the
// transaction's own chain ID, falling back to Homestead rules when it reports
// none. The sender is returned as a lower-case 0x-prefixed hex address.
// Decode, unmarshal and signature-recovery errors are returned unchanged.
func getSenderFromRawTx(rawHex string) (string, error) {
	raw, err := hexutil.Decode(rawHex)
	if err != nil {
		return "", err
	}

	tx := new(types.Transaction)
	if err := tx.UnmarshalBinary(raw); err != nil {
		return "", err
	}

	var signer types.Signer = types.HomesteadSigner{}
	if chainID := tx.ChainId(); chainID != nil {
		signer = types.LatestSignerForChainID(chainID)
	}

	sender, err := types.Sender(signer, tx)
	if err != nil {
		return "", err
	}
	return strings.ToLower(sender.Hex()), nil
}

// setCORSHeaders writes the proxy's CORS response headers for r onto w.
//
// The request's Origin is echoed back in Access-Control-Allow-Origin, and
// "*" is used when the request has no Origin. Because that value depends on
// the request's Origin, "Vary: Origin" is always added. This stops shared
// caches from serving one origin's response to another.
//
// NOTE(review): echoing an arbitrary Origin together with
// Access-Control-Allow-Credentials: true lets any web page make credentialed
// requests to this proxy; confirm this is intended.
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
	h := w.Header()

	allowOrigin := "*"
	if origin := r.Header.Get("Origin"); origin != "" {
		allowOrigin = origin
	}
	h.Set("Access-Control-Allow-Origin", allowOrigin)
	h.Add("Vary", "Origin")

	h.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	h.Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Real-IP, X-Forwarded-For")
	h.Set("Access-Control-Allow-Credentials", "true")
	h.Set("Access-Control-Max-Age", "86400")
}

// extractFromAddress scans req.Params in order and returns, lower-cased, the
// "from" field of the first object-valued parameter whose "from" is a
// non-empty string (the usual shape for eth_sendTransaction, eth_call and
// eth_estimateGas). The second result is false when no parameter qualifies,
// including when Params is empty.
func extractFromAddress(req RPCRequest) (string, bool) {
	for i := range req.Params {
		obj, isObj := req.Params[i].(map[string]interface{})
		if !isObj {
			continue
		}
		if from, isStr := obj["from"].(string); isStr && from != "" {
			return strings.ToLower(from), true
		}
	}
	return "", false
}

// accountExists reports whether tb_account_info has at least one live row
// (is_deleted = 0) whose account_address equals address. The address is
// passed to the query as given; proxyHandler lower-cases it before calling.
//
// Query errors are logged and reported as "exists" (fail open), so a database
// outage leads to requests being forwarded rather than answered locally.
func accountExists(address string) bool {
	const query = "SELECT COUNT(1) FROM tb_account_info WHERE account_address = ? AND is_deleted = 0"

	var n int
	switch err := db.QueryRow(query, address).Scan(&n); {
	case err == nil:
		return n > 0
	case err == sql.ErrNoRows:
		return false
	default:
		log.Printf("Database query error: %v", err)
		return true
	}
}

// forwardToBackend POSTs body verbatim to rpcBackend and relays the backend's
// status code, headers and body to w.
//
// The proxy owns the CORS policy. If an Access-Control-* header is already
// set on w (by setCORSHeaders), the backend's copy is dropped rather than
// appended, because browsers reject responses with duplicate
// Access-Control-Allow-Origin values. All other backend headers are copied.
//
// It replies 503 when no backend is configured and 502 when the backend
// cannot be reached. Transport and copy errors are logged.
//
// NOTE(review): http.Post uses http.DefaultClient, which has no timeout, so
// a stalled backend can block this handler for as long as it takes to answer.
func forwardToBackend(w http.ResponseWriter, body []byte) {
	if rpcBackend == "" {
		http.Error(w, "Backend RPC not configured", http.StatusServiceUnavailable)
		return
	}
	resp, err := http.Post(rpcBackend, "application/json", bytes.NewReader(body))
	if err != nil {
		log.Printf("backend RPC request failed: %v", err)
		http.Error(w, "Backend RPC request failed", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	dst := w.Header()
	for k, vals := range resp.Header {
		// resp.Header keys are canonicalised, as are keys written via Set.
		if _, preset := dst[k]; preset && strings.HasPrefix(k, "Access-Control-") {
			continue
		}
		for _, v := range vals {
			dst.Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("copy backend response error: %v", err)
	}
}

// ===== Whitelist helper functions (dynamic reload) =====

// loadWhitelist reads the whitelist file at path, one entry per line, after
// trimming surrounding whitespace. Blank lines and lines starting with "#" or
// "//" are ignored. Lines containing '*' are compiled with compilePattern;
// those that fail to compile are logged with their 1-based line number and
// skipped. Every other line is added verbatim to the exact-match set.
//
// It never fails. If the file cannot be opened, or scanning stops with an
// error, the problem is logged and whatever was collected so far is returned.
// Both results are always non-nil.
func loadWhitelist(path string) (map[string]struct{}, []*regexp.Regexp) {
	exact := map[string]struct{}{}
	patterns := []*regexp.Regexp{}

	file, err := os.Open(path)
	if err != nil {
		log.Printf("open whitelist file '%s' error: %v", path, err)
		return exact, patterns
	}
	defer file.Close()

	sc := bufio.NewScanner(file)
	for n := 1; sc.Scan(); n++ {
		entry := strings.TrimSpace(sc.Text())
		switch {
		case entry == "", strings.HasPrefix(entry, "#"), strings.HasPrefix(entry, "//"):
			// blank line or comment
		case strings.Contains(entry, "*"):
			re := compilePattern(entry)
			if re == nil {
				log.Printf("skip invalid pattern at line %d: %s", n, entry)
				continue
			}
			patterns = append(patterns, re)
		default:
			exact[entry] = struct{}{}
		}
	}
	if err := sc.Err(); err != nil {
		log.Printf("scan whitelist file error: %v", err)
	}
	return exact, patterns
}

// isWhitelisted reports whether v (an Origin or Referer header value) matches
// the current whitelist, either exactly or through one of its wildcard
// patterns. Surrounding whitespace is ignored, and a blank value never matches.
//
// Besides the raw value, the following forms are tried when v parses as a URL
// with a host:
//   - the bare hostname (only when it differs from the host, i.e. a port or
//     IPv6 brackets were present);
//   - the host as written (host[:port]);
//   - the origin "scheme://host[:port]".
//
// Lookups hold whitelistMu for reading, so they can run concurrently with a
// reload by the whitelist watcher.
func isWhitelisted(v string) bool {
	val := strings.TrimSpace(v)
	if val == "" {
		return false
	}

	candidates := make([]string, 0, 4)
	candidates = append(candidates, val)
	if u, err := url.Parse(val); err == nil && u.Host != "" {
		host := u.Host
		// u.Hostname() strips the port and unwraps bracketed IPv6 literals.
		// Splitting on ':' instead would turn "[::1]:8545" into "[".
		if hostname := u.Hostname(); hostname != host {
			candidates = append(candidates, hostname)
		}
		candidates = append(candidates, host)
		candidates = append(candidates, fmt.Sprintf("%s://%s", u.Scheme, host))
	}

	whitelistMu.RLock()
	defer whitelistMu.RUnlock()
	for _, c := range candidates {
		if _, ok := whitelist[c]; ok {
			return true
		}
	}
	for _, re := range whitelistPatterns {
		for _, c := range candidates {
			if re.MatchString(c) {
				return true
			}
		}
	}
	return false
}

// compilePattern converts a whitelist line containing '*' wildcards into an
// anchored regular expression. Every other character matches literally.
//
// A '*' matches any run of characters except '/', '?', '#' and '@'. This
// keeps a wildcard from reaching into the path, query, fragment or userinfo
// of a raw Origin/Referer value. isWhitelisted also tries the bare host and
// "scheme://host" forms, which is what host-oriented patterns match.
//
// Supported examples:
//
//	*.example.com     -> one or more subdomain labels of example.com
//	                     (a.example.com, a.b.example.com; not example.com)
//	example.com*      -> prefix match (e.g. example.com:8080)
//	*example.com      -> suffix match
//	*mid*             -> substring match
//	https://*.foo.bar -> scheme + subdomain
//
// It returns nil for a blank pattern or one that fails to compile (the
// failure is logged).
func compilePattern(p string) *regexp.Regexp {
	// wildcard is what each '*' in the pattern becomes.
	const wildcard = `[^/?#@]*`
	// label is a single host label as allowed in front of a "*." root.
	const label = `[A-Za-z0-9_-]+`

	p = strings.TrimSpace(p)
	if p == "" {
		return nil
	}

	// literal escapes s for use in a regexp and turns each '*' into wildcard.
	// QuoteMeta renders '*' as `\*`, so it is that escaped form that must be
	// replaced. Replacing the bare '*' would leave `\.*`, i.e. "any number of
	// literal dots".
	literal := func(s string) string {
		return strings.ReplaceAll(regexp.QuoteMeta(s), `\*`, wildcard)
	}

	var pattern string
	if strings.HasPrefix(p, "*.") {
		// Special host wildcard prefix '*.': require at least one subdomain
		// label in front of the root.
		pattern = fmt.Sprintf(`^(?:%s\.)+%s$`, label, literal(strings.TrimPrefix(p, "*.")))
	} else {
		pattern = fmt.Sprintf("^%s$", literal(p))
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		log.Printf("compile pattern error (%s): %v", p, err)
		return nil
	}
	return re
}

// startWhitelistWatcher starts a background goroutine that keeps the in-memory
// whitelist in sync with the file at path.
//
// The parent directory is watched rather than the file itself, so the file
// can be replaced (created/renamed) as well as edited in place. Events for
// path reload the whitelist on Write, Create or Rename and clear it on Remove.
// Setup errors are logged and leave the whitelist as it is.
//
// NOTE(review): the goroutine has no stop mechanism; it lives for the rest of
// the process.
func startWhitelistWatcher(path string) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Printf("create whitelist watcher error: %v", err)
		return
	}

	// Event names are derived from the watched directory and need not be
	// textually identical to path (e.g. "./whitelist.txt" vs
	// "whitelist.txt"), so both sides are compared in cleaned form.
	target := filepath.Clean(path)
	dir := filepath.Dir(target)
	if err := watcher.Add(dir); err != nil {
		log.Printf("add whitelist watch dir error: %v", err)
		watcher.Close()
		return
	}

	go func() {
		defer watcher.Close()
		for {
			select {
			case ev, ok := <-watcher.Events:
				if !ok {
					return
				}
				if filepath.Clean(ev.Name) != target {
					continue
				}
				if ev.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
					exact, pats := loadWhitelist(path)
					whitelistMu.Lock()
					whitelist = exact
					whitelistPatterns = pats
					whitelistMu.Unlock()
					log.Printf("whitelist reloaded (%d exact, %d patterns) due to event: %s", len(exact), len(pats), ev.Op.String())
				}
				if ev.Op&fsnotify.Remove != 0 {
					whitelistMu.Lock()
					whitelist = map[string]struct{}{}
					whitelistPatterns = []*regexp.Regexp{}
					whitelistMu.Unlock()
					log.Printf("whitelist file removed, cleared entries")
				}
			case err, ok := <-watcher.Errors:
				if !ok {
					return
				}
				log.Printf("whitelist watcher error: %v", err)
			}
		}
	}()
}
