Commit c802d497 authored by vicotor's avatar vicotor

add cors support

parent d34c2e55
...@@ -35,9 +35,9 @@ func (t *TbAccountInfo) TableName() string { ...@@ -35,9 +35,9 @@ func (t *TbAccountInfo) TableName() string {
return "tb_account_info" return "tb_account_info"
} }
// RPC请求结构体 // RPC request structure
// 这里只处理eth_getBalance的params // Only handle params for eth_getBalance here
// 其他方法直接转发 // Forward other methods directly
type RPCRequest struct { type RPCRequest struct {
Jsonrpc string `json:"jsonrpc"` Jsonrpc string `json:"jsonrpc"`
...@@ -55,36 +55,45 @@ type RPCResponse struct { ...@@ -55,36 +55,45 @@ type RPCResponse struct {
var ( var (
db *sql.DB db *sql.DB
rpcBackend = os.Getenv("ETH_RPC_BACKEND") // 真实以太坊RPC地址,建议用环境变量配置 rpcBackend = os.Getenv("ETH_RPC_BACKEND") // Real Ethereum RPC address, recommend using environment variable
) )
func main() { func main() {
// 初始化数据库连接 // Initialize database connection
var err error var err error
dsn := os.Getenv("MYSQL_DSN") // 例如 "user:password@tcp(127.0.0.1:3306)/dbname" dsn := os.Getenv("MYSQL_DSN") // Example: "user:password@tcp(127.0.0.1:3306)/dbname"
db, err = sql.Open("mysql", dsn) db, err = sql.Open("mysql", dsn)
if err != nil { if err != nil {
log.Fatalf("数据库连接失败: %v", err) log.Fatalf("Database connection failed: %v", err)
} }
defer db.Close() defer db.Close()
http.HandleFunc("/", proxyHandler) http.HandleFunc("/", proxyHandler)
log.Println("RPC代理服务启动,监听端口: 8545") log.Println("RPC proxy service started, listening on port: 8545")
log.Fatal(http.ListenAndServe(":8545", nil)) log.Fatal(http.ListenAndServe(":8545", nil))
} }
func proxyHandler(w http.ResponseWriter, r *http.Request) { func proxyHandler(w http.ResponseWriter, r *http.Request) {
// Add CORS support
setCORSHeaders(w, r)
// Handle preflight requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
http.Error(w, "读取请求失败", http.StatusBadRequest) http.Error(w, "Failed to read request", http.StatusBadRequest)
return return
} }
defer r.Body.Close() defer r.Body.Close()
var reqs []RPCRequest var reqs []RPCRequest
// 先尝试解析为批量请求 // Try to parse as batch request first
if err := json.Unmarshal(body, &reqs); err == nil { if err := json.Unmarshal(body, &reqs); err == nil {
// 处理批量请求 // Handle batch request
if len(reqs) > 1 { if len(reqs) > 1 {
req := reqs[0] req := reqs[0]
resp := RPCResponse{ resp := RPCResponse{
...@@ -128,10 +137,24 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) { ...@@ -128,10 +137,24 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
// 其他情况直接转发 // Forward other cases directly
forwardToBackend(w, body) forwardToBackend(w, body)
} }
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
} else {
w.Header().Set("Access-Control-Allow-Origin", "*")
}
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Real-IP, X-Forwarded-For")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "86400")
}
func accountExists(address string) bool { func accountExists(address string) bool {
var count int var count int
query := "SELECT COUNT(1) FROM tb_account_info WHERE account_address = ? AND is_deleted = 0" query := "SELECT COUNT(1) FROM tb_account_info WHERE account_address = ? AND is_deleted = 0"
...@@ -140,7 +163,7 @@ func accountExists(address string) bool { ...@@ -140,7 +163,7 @@ func accountExists(address string) bool {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false return false
} else { } else {
log.Printf("数据库查询错误: %v", err) log.Printf("Database query error: %v", err)
} }
return true return true
} }
...@@ -150,10 +173,18 @@ func accountExists(address string) bool { ...@@ -150,10 +173,18 @@ func accountExists(address string) bool {
func forwardToBackend(w http.ResponseWriter, body []byte) { func forwardToBackend(w http.ResponseWriter, body []byte) {
resp, err := http.Post(rpcBackend, "application/json", bytes.NewReader(body)) resp, err := http.Post(rpcBackend, "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
http.Error(w, "后端RPC请求失败", http.StatusBadGateway) http.Error(w, "Backend RPC request failed", http.StatusBadGateway)
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
w.Header().Set("Content-Type", "application/json")
// Copy response headers, including possible CORS headers
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body) io.Copy(w, resp.Body)
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment