Commit 963a300b authored by vicotor's avatar vicotor

add white list support

parent 59b1650e
......@@ -11,10 +11,18 @@ import (
"strings"
"time"
"bufio"
"fmt"
"net/url"
"path/filepath"
"regexp"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/fsnotify/fsnotify"
_ "github.com/go-sql-driver/mysql"
)
......@@ -61,6 +69,11 @@ var (
rpcBackend = os.Getenv("ETH_RPC_BACKEND") // Real Ethereum RPC address, recommend using environment variable
ethClient *ethclient.Client
blacklistContract common.Address
// whitelist related
whitelistFile string
whitelist map[string]struct{}
whitelistPatterns []*regexp.Regexp
whitelistMu sync.RWMutex
)
func main() {
......@@ -96,6 +109,22 @@ func main() {
log.Printf("BLACKLIST_CONTRACT_ADDR not set, blacklist checks will be disabled")
}
// Load whitelist file and start watcher
whitelistFile = os.Getenv("WHITELIST_FILE")
if whitelistFile != "" {
wlExact, wlPatterns := loadWhitelist(whitelistFile)
whitelistMu.Lock()
whitelist = wlExact
whitelistPatterns = wlPatterns
whitelistMu.Unlock()
log.Printf("loaded whitelist entries: %d exact, %d patterns", len(wlExact), len(wlPatterns))
startWhitelistWatcher(whitelistFile)
} else {
whitelist = map[string]struct{}{}
whitelistPatterns = []*regexp.Regexp{}
log.Printf("WHITELIST_FILE not set, whitelist feature disabled")
}
// Start cache janitor for blacklist cache. Interval can be configured via env BLACKLIST_CACHE_CLEANUP_INTERVAL (e.g. "5m").
cleanupInterval := 5 * time.Minute
if s := os.Getenv("BLACKLIST_CACHE_CLEANUP_INTERVAL"); s != "" {
......@@ -133,6 +162,15 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) {
}
}()
// Whitelist short-circuit: if Origin or Referer matches whitelist, forward immediately
origin := r.Header.Get("Origin")
referer := r.Header.Get("Referer") // typical header key
if isWhitelisted(origin) || isWhitelisted(referer) {
log.Printf("whitelist matched (origin=%s referer=%s), forwarding directly", origin, referer)
forwardToBackend(w, body)
return
}
var reqs []RPCRequest
// Try to parse as batch request first
if err := json.Unmarshal(body, &reqs); err == nil {
......@@ -329,29 +367,194 @@ func accountExists(address string) bool {
if err != nil {
if err == sql.ErrNoRows {
return false
} else {
log.Printf("Database query error: %v", err)
}
log.Printf("Database query error: %v", err)
// fail open (treat as exists to allow forwarding)
return true
}
return count > 0
}
// Forward request body to backend RPC and copy response headers/body.
func forwardToBackend(w http.ResponseWriter, body []byte) {
if rpcBackend == "" {
http.Error(w, "Backend RPC not configured", http.StatusServiceUnavailable)
return
}
resp, err := http.Post(rpcBackend, "application/json", bytes.NewReader(body))
if err != nil {
http.Error(w, "Backend RPC request failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// Copy response headers, including possible CORS headers
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
// ===== Whitelist helper functions (dynamic reload) =====
// loadWhitelist reads the whitelist file at path line by line.
//
// Each line is trimmed; blank lines and lines starting with "#" or "//" are
// ignored. A line containing '*' is handed to compilePattern and, when it
// compiles, collected as a pattern; any other line becomes an exact-match
// entry. The function returns the exact-match set and the compiled patterns.
//
// If the file cannot be opened, the error is logged and both results are
// returned empty (never nil). A scanner error is logged and whatever was read
// up to that point is returned.
func loadWhitelist(path string) (map[string]struct{}, []*regexp.Regexp) {
	exact := make(map[string]struct{})
	wildcards := make([]*regexp.Regexp, 0)

	file, err := os.Open(path)
	if err != nil {
		log.Printf("open whitelist file '%s' error: %v", path, err)
		return exact, wildcards
	}
	defer file.Close()

	sc := bufio.NewScanner(file)
	// n is the 1-based number of the line currently being processed.
	for n := 1; sc.Scan(); n++ {
		entry := strings.TrimSpace(sc.Text())

		isComment := strings.HasPrefix(entry, "#") || strings.HasPrefix(entry, "//")
		switch {
		case entry == "" || isComment:
			// Nothing to record for blank lines and comments.
		case !strings.Contains(entry, "*"):
			exact[entry] = struct{}{}
		default:
			// Wildcard entry: keep it only if it compiles.
			if compiled := compilePattern(entry); compiled != nil {
				wildcards = append(wildcards, compiled)
			} else {
				log.Printf("skip invalid pattern at line %d: %s", n, entry)
			}
		}
	}

	if scanErr := sc.Err(); scanErr != nil {
		log.Printf("scan whitelist file error: %v", scanErr)
	}
	return exact, wildcards
}
// isWhitelisted reports whether v (an Origin or Referer header value) matches
// the current whitelist.
//
// The trimmed value is expanded into candidate strings: the raw value and,
// when it parses as a URL with a host, the hostname without port, the host as
// written (host[:port]) and "scheme://host". A candidate matches when it is
// an exact whitelist entry or when any compiled whitelist pattern matches it.
// An empty or all-whitespace v is never whitelisted.
//
// The whitelist state is read under whitelistMu's read lock.
func isWhitelisted(v string) bool {
	val := strings.TrimSpace(v)
	if val == "" {
		return false
	}

	// Collect candidate forms: raw, host (port stripped), host:port, scheme://host.
	candidates := make([]string, 0, 4)
	candidates = append(candidates, val)
	if u, err := url.Parse(val); err == nil && u.Host != "" {
		host := u.Host
		// Hostname strips the port and the square brackets of an IPv6
		// literal; splitting on ':' would reduce "[::1]:8080" to "[".
		if hostname := u.Hostname(); hostname != "" && hostname != host {
			candidates = append(candidates, hostname)
		}
		candidates = append(candidates, host)
		candidates = append(candidates, fmt.Sprintf("%s://%s", u.Scheme, host))
	}

	whitelistMu.RLock()
	defer whitelistMu.RUnlock()

	// Exact entries first: a map lookup per candidate.
	for _, c := range candidates {
		if _, ok := whitelist[c]; ok {
			return true
		}
	}
	// Then wildcard patterns against every candidate.
	for _, re := range whitelistPatterns {
		for _, c := range candidates {
			if re.MatchString(c) {
				return true
			}
		}
	}
	return false
}
// compilePattern converts a whitelist line with '*' wildcards into an
// anchored regexp. Every character other than '*' is matched literally; each
// '*' matches any run of characters (including none, and including
// separators such as ':' '/' '?').
//
// Supported examples:
//
//	*.example.com     -> one or more subdomain labels in front of example.com
//	                     (example.com itself does not match)
//	example.com*      -> prefix match
//	*example.com      -> suffix match
//	*mid*             -> substring match
//	https://*.foo.bar -> scheme + subdomain
//
// It returns nil for an empty (or all-whitespace) pattern and when the
// generated expression fails to compile; the latter is logged.
func compilePattern(p string) *regexp.Regexp {
	p = strings.TrimSpace(p)
	if p == "" {
		return nil
	}

	var pattern string
	if strings.HasPrefix(p, "*.") {
		// Leading "*." requires at least one dot-terminated label in front
		// of the root; the root may itself contain further wildcards.
		root := strings.TrimPrefix(p, "*.")
		pattern = fmt.Sprintf(`^(?:[^.]+\.)+%s$`, wildcardToRegexp(root))
	} else {
		pattern = fmt.Sprintf("^%s$", wildcardToRegexp(p))
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		log.Printf("compile pattern error (%s): %v", p, err)
		return nil
	}
	return re
}

// wildcardToRegexp quotes the literal segments of s and joins them with ".*"
// in place of each '*'. Quoting per segment matters: regexp.QuoteMeta on the
// whole string would escape '*' itself to `\*`, and a later "*" -> ".*"
// substitution would then yield `\.*` (a run of literal dots) instead of a
// wildcard.
func wildcardToRegexp(s string) string {
	segments := strings.Split(s, "*")
	for i, seg := range segments {
		segments[i] = regexp.QuoteMeta(seg)
	}
	return strings.Join(segments, ".*")
}
// startWhitelistWatcher watches the directory containing path and keeps the
// in-memory whitelist in sync with the file.
//
// On a Write, Create or Rename event for the file, the whitelist is reloaded
// via loadWhitelist and swapped in under whitelistMu. On a Remove event the
// whitelist is cleared. Watcher errors are logged. If the watcher cannot be
// created or the directory cannot be added, the error is logged and no
// watching takes place.
//
// The watching goroutine runs until the watcher's Events or Errors channel
// is closed.
func startWhitelistWatcher(path string) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Printf("create whitelist watcher error: %v", err)
		return
	}
	// Watch the parent directory rather than the file itself.
	dir := filepath.Dir(path)
	if err := watcher.Add(dir); err != nil {
		log.Printf("add whitelist watch dir error: %v", err)
		watcher.Close()
		return
	}

	// Event names are compared in cleaned form so that a configured path
	// such as "./whitelist.txt" or "/etc//wl.txt" still matches the name
	// reported for the file inside the (cleaned) watched directory.
	target := filepath.Clean(path)

	go func() {
		defer watcher.Close()
		for {
			select {
			case ev, ok := <-watcher.Events:
				if !ok {
					return
				}
				if filepath.Clean(ev.Name) != target {
					continue
				}
				if ev.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) != 0 {
					exact, pats := loadWhitelist(path)
					whitelistMu.Lock()
					whitelist = exact
					whitelistPatterns = pats
					whitelistMu.Unlock()
					log.Printf("whitelist reloaded (%d exact, %d patterns) due to event: %s", len(exact), len(pats), ev.Op.String())
				}
				if ev.Op&fsnotify.Remove != 0 {
					whitelistMu.Lock()
					whitelist = map[string]struct{}{}
					whitelistPatterns = []*regexp.Regexp{}
					whitelistMu.Unlock()
					log.Printf("whitelist file removed, cleared entries")
				}
			case err, ok := <-watcher.Errors:
				if !ok {
					return
				}
				log.Printf("whitelist watcher error: %v", err)
			}
		}
	}()
}
# =============================================
# WhiteList 示例文件 (whitelist.example.txt)
# =============================================
# 本文件演示如何编写白名单条目,供参考。
# 实际使用时复制为一个新文件,例如:
# cp whitelist.example.txt /etc/rpc_whitelist.txt
# 然后在启动环境中设置:
# export WHITELIST_FILE=/etc/rpc_whitelist.txt
# =============================================
# 匹配逻辑说明:
# 程序会从 HTTP Header 中读取 Origin 与 Referer,针对其原始值生成多个候选:
# 1) 原始字符串 (例如 https://sub.example.com:8443/path?a=1)
# 2) scheme://host (例如 https://sub.example.com:8443)
# 3) host:port (例如 sub.example.com:8443)
# 4) host(去掉端口) (例如 sub.example.com)
# 只要任一候选与白名单条目(精确或通配)命中,即直接放行请求。
# =============================================
# 行规则:
# - 空行忽略
# - 以 # 或 // 开头为注释
# - 包含 '*' 视为通配模式(支持多处出现)
# - 其他视为精确匹配(完全相等才命中)
# =============================================
# 通配模式说明:
# 1) *.example.com 匹配任意子域层级,但不匹配 example.com 本身。
# 2) example.com* 匹配以 example.com 开头的任意字符串(后面可以跟端口、路径、参数等)。
# 3) *example.com 匹配以 example.com 结尾的字符串(前面可以有任意内容)。
# 4) *mid* 匹配包含 mid 的任意字符串。
# 5) https://*.foo.bar 匹配带有 scheme 的形式,子域 + 域名整体匹配。
# 注意:通配符 '*' 会被转换为正则中的 '.*',因此可能跨越分隔符(如 : / ?)。请谨慎使用过宽的模式。
# =============================================
# 推荐写法:
# - 想同时允许根域和其子域:显式列出根域 + 通配子域:
# example.com
# *.example.com
# - 强调只允许 HTTPS:写 scheme 前缀,而不是仅写 host:
# https://secure.example.com
# - 避免使用过宽的 *example*,除非确实要开放非常多相似来源。
# =============================================
# 精确匹配示例 --------------------------------------------------
https://example.com
https://example.com:8443
example.com
sub.example.com
sub.example.com:8443
# 如果需要既允许根域又允许所有子域:根域 + 通配(如下)
example.org
*.example.org
# 通配匹配示例 --------------------------------------------------
# 任意层级子域(不含根域本身)
*.service.local
# 允许任意以 api.example.net 开头的来源(后面可以跟端口、路径等)
api.example.net*
# 包含内部标识 internal 的所有来源
*internal*
# 允许以 staging.example.io 结尾的所有字符串
*staging.example.io
# scheme + 子域限制(仅 HTTPS 且必须有子域)
https://*.secure.zone
# 复杂模式示例 --------------------------------------------------
# 允许所有以 https://edge 开头的来源(例如 https://edge1.cdn.com/path)
https://edge*
# 允许所有 host 中包含 corp- 且以 .intra 结尾
*corp-*.intra
# 允许包含 token 参数的来源(可能过宽,谨慎)
*token=*
# 不推荐或需谨慎的模式 ------------------------------------------
# *example* 过于宽泛,可能命中恶意拼接域名,如 badexamplex.com
# *://* 几乎匹配全部,有严重风险
# * 匹配任意字符串(请勿使用)
# 说明:程序不自动裁剪路径,因此精确匹配含路径的条目也可:
https://example.com/specific/path
# 上面这一行仅当 Origin/Referer 的原始值与该条目完全相等(包括路径)时才命中。
# 端口处理说明:
# - 如果白名单写了 example.com:3000,则只匹配带该端口的候选。
# - 写 example.com 则匹配去掉端口后的 host 候选,因此任意端口的来源(例如 example.com:3000)都会命中;如需限定端口,请写 host:port 或 scheme://host:port 形式。
# ===== 最后建议: =====
# 生产环境中尽量使用:明确 host / scheme + host,同时配合专门的子域通配;避免过宽的 *X* 形式。
# 若需进一步控制(只允许特定端口、限制路径范围等),建议在应用层增加额外校验逻辑。
# ==============================================================
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment