Commit a602c5c3 authored by Michael Amadi's avatar Michael Amadi Committed by GitHub

transition interface check script to common framework (#13390)

* transition interface check script to common framework

* fixes

* fixes
parent 07b6011c
...@@ -7,12 +7,10 @@ import ( ...@@ -7,12 +7,10 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"sort" "sort"
"strings" "strings"
"sync"
"sync/atomic"
"github.com/ethereum-optimism/optimism/packages/contracts-bedrock/scripts/checks/common"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
...@@ -52,82 +50,32 @@ type Artifact struct { ...@@ -52,82 +50,32 @@ type Artifact struct {
} }
func main() { func main() {
if err := run(); err != nil { if err := common.ProcessFilesGlob(
writeStderr("an error occurred: %v", err) []string{"forge-artifacts/**/*.json"},
[]string{},
processFile,
); err != nil {
fmt.Printf("error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
} }
func writeStderr(msg string, args ...any) { func processFile(artifactPath string) []error {
_, _ = fmt.Fprintf(os.Stderr, msg+"\n", args...)
}
func run() error {
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
return fmt.Errorf("failed to get current working directory: %w", err) return []error{fmt.Errorf("failed to get current working directory: %w", err)}
} }
artifactsDir := filepath.Join(cwd, "forge-artifacts") artifactsDir := filepath.Join(cwd, "forge-artifacts")
artifactFiles, err := glob(artifactsDir, ".json") contractName := strings.Split(filepath.Base(artifactPath), ".")[0]
if err != nil {
return fmt.Errorf("failed to get artifact files: %w", err)
}
// Remove duplicates from artifactFiles
uniqueArtifacts := make(map[string]string)
for contractName, artifactPath := range artifactFiles {
baseName := strings.Split(contractName, ".")[0]
uniqueArtifacts[baseName] = artifactPath
}
var hasErr int32
var outMtx sync.Mutex
fail := func(msg string, args ...any) {
outMtx.Lock()
writeStderr("❌ "+msg, args...)
outMtx.Unlock()
atomic.StoreInt32(&hasErr, 1)
}
sem := make(chan struct{}, runtime.NumCPU())
for contractName, artifactPath := range uniqueArtifacts {
contractName := contractName
artifactPath := artifactPath
sem <- struct{}{}
go func() {
defer func() {
<-sem
}()
if err := processArtifact(contractName, artifactPath, artifactsDir, fail); err != nil {
fail("%s: %v", contractName, err)
}
}()
}
for i := 0; i < cap(sem); i++ {
sem <- struct{}{}
}
if atomic.LoadInt32(&hasErr) == 1 {
return errors.New("interface check failed, see logs above")
}
return nil
}
func processArtifact(contractName, artifactPath, artifactsDir string, fail func(string, ...any)) error {
if isExcluded(contractName) { if isExcluded(contractName) {
return nil return nil
} }
artifact, err := readArtifact(artifactPath) artifact, err := readArtifact(artifactPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read artifact: %w", err) return []error{fmt.Errorf("failed to read artifact: %w", err)}
} }
contractDef := getContractDefinition(artifact, contractName) contractDef := getContractDefinition(artifact, contractName)
...@@ -140,16 +88,16 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func( ...@@ -140,16 +88,16 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func(
} }
if !strings.HasPrefix(contractName, "I") { if !strings.HasPrefix(contractName, "I") {
fail("%s: Interface does not start with 'I'", contractName) return []error{fmt.Errorf("%s: Interface does not start with 'I'", contractName)}
} }
semver, err := getContractSemver(artifact) semver, err := getContractSemver(artifact)
if err != nil { if err != nil {
return err return []error{fmt.Errorf("failed to get contract semver: %w", err)}
} }
if semver != "solidity^0.8.0" { if semver != "solidity^0.8.0" {
fail("%s: Interface does not have correct compiler version (MUST be exactly solidity ^0.8.0)", contractName) return []error{fmt.Errorf("%s: Interface does not have correct compiler version (MUST be exactly solidity ^0.8.0)", contractName)}
} }
contractBasename := contractName[1:] contractBasename := contractName[1:]
...@@ -161,7 +109,7 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func( ...@@ -161,7 +109,7 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func(
contractArtifact, err := readArtifact(correspondingContractFile) contractArtifact, err := readArtifact(correspondingContractFile)
if err != nil { if err != nil {
return fmt.Errorf("failed to read corresponding contract artifact: %w", err) return []error{fmt.Errorf("failed to read corresponding contract artifact: %w", err)}
} }
interfaceABI := artifact.ABI interfaceABI := artifact.ABI
...@@ -169,20 +117,20 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func( ...@@ -169,20 +117,20 @@ func processArtifact(contractName, artifactPath, artifactsDir string, fail func(
normalizedInterfaceABI, err := normalizeABI(interfaceABI) normalizedInterfaceABI, err := normalizeABI(interfaceABI)
if err != nil { if err != nil {
return fmt.Errorf("failed to normalize interface ABI: %w", err) return []error{fmt.Errorf("failed to normalize interface ABI: %w", err)}
} }
normalizedContractABI, err := normalizeABI(contractABI) normalizedContractABI, err := normalizeABI(contractABI)
if err != nil { if err != nil {
return fmt.Errorf("failed to normalize contract ABI: %w", err) return []error{fmt.Errorf("failed to normalize contract ABI: %w", err)}
} }
match, err := compareABIs(normalizedInterfaceABI, normalizedContractABI) match, err := compareABIs(normalizedInterfaceABI, normalizedContractABI)
if err != nil { if err != nil {
return fmt.Errorf("failed to compare ABIs: %w", err) return []error{fmt.Errorf("failed to compare ABIs: %w", err)}
} }
if !match { if !match {
fail("%s: Differences found in ABI between interface and actual contract", contractName) return []error{fmt.Errorf("%s: Differences found in ABI between interface and actual contract", contractName)}
} }
return nil return nil
...@@ -338,17 +286,3 @@ func isExcluded(contractName string) bool { ...@@ -338,17 +286,3 @@ func isExcluded(contractName string) bool {
} }
return false return false
} }
func glob(dir string, ext string) (map[string]string, error) {
out := make(map[string]string)
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if !info.IsDir() && filepath.Ext(path) == ext {
out[strings.TrimSuffix(filepath.Base(path), ext)] = path
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to walk directory: %w", err)
}
return out, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment