forked from leonid-shevtsov/split_tests
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
168 lines (143 loc) · 4.68 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package main
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
"github.com/bmatcuk/doublestar"
)
// Program configuration, populated by parseFlags from command-line flags and
// (for some values) CircleCI environment variables.
var (
	// useCircleCI is true when a CircleCI API key is available.
	useCircleCI bool
	// useJUnitXML selects JUnit XML reports as the source of test times (-junit).
	useJUnitXML bool
	// useLineCount selects line counts as the test-time estimate (-line-count).
	useLineCount bool
	// junitXMLPath is the -junit-path value.
	junitXMLPath string

	// testFilePattern is the glob that selects test files (-glob).
	testFilePattern string
	// excludeFilePattern is the glob whose matches are removed from the test set (-exclude-glob).
	excludeFilePattern string

	// circleCIProjectPrefix is the -circleci-project value.
	circleCIProjectPrefix string
	// circleCIBranchName comes from -circleci-branch or CIRCLE_BRANCH.
	circleCIBranchName string

	// splitIndex is this container's index (-split-index or CIRCLE_NODE_INDEX).
	splitIndex int
	// splitTotal is the number of containers (-split-total or CIRCLE_NODE_TOTAL).
	splitTotal int

	// circleCIAPIKey comes from -circleci-key or CIRCLECI_API_KEY.
	circleCIAPIKey string
)
// printMsg writes msg to standard error. When args are supplied, msg is used
// as a fmt format string; with no args it is written verbatim, so a literal
// '%' in an argument-free message is not interpreted as a verb.
func printMsg(msg string, args ...interface{}) {
	if len(args) > 0 {
		fmt.Fprintf(os.Stderr, msg, args...)
		return
	}
	fmt.Fprint(os.Stderr, msg)
}
// fatalMsg reports msg on standard error (formatted exactly as printMsg does)
// and then terminates the process with exit status 1.
func fatalMsg(msg string, args ...interface{}) {
	const failureExitCode = 1
	printMsg(msg, args...)
	os.Exit(failureExitCode)
}
// removeDeletedFiles prunes fileTimes in place, dropping every entry whose
// path is not mapped to true in currentFileSet. Deleting map entries while
// ranging over the same map is permitted by the Go spec.
func removeDeletedFiles(fileTimes map[string]float64, currentFileSet map[string]bool) {
	for path := range fileTimes {
		if stillPresent := currentFileSet[path]; stillPresent {
			continue
		}
		delete(fileTimes, path)
	}
}
// addNewFiles gives every path in currentFileSet that has no entry in
// fileTimes an estimated time. The estimate is the mean of the entries already
// present, or 1.0 when fileTimes is empty.
//
// When timings come from CircleCI or JUnit XML, each file that needed an
// estimate is reported on standard error.
func addNewFiles(fileTimes map[string]float64, currentFileSet map[string]bool) {
	estimate := 1.0
	if known := len(fileTimes); known > 0 {
		var total float64
		for _, t := range fileTimes {
			total += t
		}
		estimate = total / float64(known)
	}

	reportMissing := useCircleCI || useJUnitXML
	for file := range currentFileSet {
		if _, found := fileTimes[file]; found {
			continue
		}
		if reportMissing {
			printMsg("missing file time for %s\n", file)
		}
		fileTimes[file] = estimate
	}
}
// parseFlags registers and parses the command-line flags, fills in unset
// values from environment variables, and validates the result.
//
// Environment fallbacks (used only when the flag was not given):
//   - -circleci-key    <- CIRCLECI_API_KEY
//   - -circleci-branch <- CIRCLE_BRANCH
//   - -split-total     <- CIRCLE_NODE_TOTAL
//   - -split-index     <- CIRCLE_NODE_INDEX
//
// CircleCI timing is enabled whenever an API key is present. With -help the
// usage text is printed and the process exits with status 1. Incomplete
// CircleCI settings or an invalid split configuration terminate the process
// via fatalMsg. A split configuration is valid when splitTotal >= 1 and
// 0 <= splitIndex < splitTotal.
func parseFlags() {
	flag.StringVar(&testFilePattern, "glob", "spec/**/*_spec.rb", "Glob pattern to find test files")
	flag.StringVar(&excludeFilePattern, "exclude-glob", "", "Glob pattern to exclude test files")
	flag.IntVar(&splitIndex, "split-index", -1, "This test container's index (or set CIRCLE_NODE_INDEX)")
	flag.IntVar(&splitTotal, "split-total", -1, "Total number of containers (or set CIRCLE_NODE_TOTAL)")
	flag.StringVar(&circleCIAPIKey, "circleci-key", "", "CircleCI API key (or set CIRCLECI_API_KEY environment variable) - required to use CircleCI")
	flag.StringVar(&circleCIProjectPrefix, "circleci-project", "", "CircleCI project name (e.g. github/leonid-shevtsov/split_tests) - required to use CircleCI")
	flag.StringVar(&circleCIBranchName, "circleci-branch", "", "Current branch for CircleCI (or set CIRCLE_BRANCH) - required to use CircleCI")
	flag.BoolVar(&useJUnitXML, "junit", false, "Use a JUnit XML report for test times")
	flag.StringVar(&junitXMLPath, "junit-path", "", "Path to a JUnit XML report (leave empty to read from stdin; use glob pattern to load multiple files)")
	flag.BoolVar(&useLineCount, "line-count", false, "Use line count to estimate test times")
	var showHelp bool
	flag.BoolVar(&showHelp, "help", false, "Show this help text")
	flag.Parse()

	if circleCIAPIKey == "" {
		circleCIAPIKey = os.Getenv("CIRCLECI_API_KEY")
	}
	if circleCIBranchName == "" {
		circleCIBranchName = os.Getenv("CIRCLE_BRANCH")
	}
	// -1 is the "not given" sentinel for both split flags. An unset or
	// malformed environment value leaves the sentinel in place so the
	// validation below rejects it. (Previously a bad CIRCLE_NODE_TOTAL reset
	// splitIndex instead of splitTotal, clobbering an explicit -split-index.)
	if splitTotal == -1 {
		if total, err := strconv.Atoi(os.Getenv("CIRCLE_NODE_TOTAL")); err == nil {
			splitTotal = total
		}
	}
	if splitIndex == -1 {
		if index, err := strconv.Atoi(os.Getenv("CIRCLE_NODE_INDEX")); err == nil {
			splitIndex = index
		}
	}

	useCircleCI = circleCIAPIKey != ""

	if showHelp {
		printMsg("Splits test files into containers of even duration\n\n")
		flag.PrintDefaults()
		os.Exit(1)
	}
	if useCircleCI && (circleCIProjectPrefix == "" || circleCIBranchName == "") {
		fatalMsg("Incomplete CircleCI configuration (set -circleci-key, -circleci-project, and -circleci-branch)\n")
	}
	// splitIndex is used as a zero-based index into the per-container buckets
	// in main, so it must be strictly less than splitTotal. A non-positive
	// total (including the -1 sentinel) is never valid.
	if splitTotal < 1 || splitIndex < 0 || splitIndex >= splitTotal {
		fatalMsg("-split-index and -split-total (and environment variables) are missing or invalid\n")
	}
}
// main prints, space-separated on standard output, the test files assigned to
// this container.
//
// It parses configuration (parseFlags) and then expands -glob, removing any
// -exclude-glob matches. Next it loads per-file timings from the selected
// source: line counts, JUnit XML, or CircleCI, checked in that order. It drops
// timings for files not in the current set and estimates timings for files
// that have none. Finally it passes the timings and splitTotal to splitFiles
// and prints bucket splitIndex.
//
// Failure to enumerate files is fatal (exit status 1).
func main() {
	parseFlags()

	// We are not using filepath.Glob,
	// because it doesn't support '**' (to match all files in all nested directories)
	currentFiles, err := doublestar.Glob(testFilePattern)
	if err != nil {
		fatalMsg("failed to enumerate current file set: %v\n", err)
	}
	currentFileSet := make(map[string]bool, len(currentFiles))
	for _, file := range currentFiles {
		currentFileSet[file] = true
	}

	if excludeFilePattern != "" {
		excludedFiles, err := doublestar.Glob(excludeFilePattern)
		if err != nil {
			fatalMsg("failed to enumerate excluded file set: %v\n", err)
		}
		for _, file := range excludedFiles {
			delete(currentFileSet, file)
		}
	}

	fileTimes := make(map[string]float64)
	switch {
	case useLineCount:
		estimateFileTimesByLineCount(currentFileSet, fileTimes)
	case useJUnitXML:
		getFileTimesFromJUnitXML(fileTimes)
	case useCircleCI:
		getFileTimesFromCircleCI(fileTimes)
	}
	removeDeletedFiles(fileTimes, currentFileSet)
	addNewFiles(fileTimes, currentFileSet)

	buckets, bucketTimes := splitFiles(fileTimes, splitTotal)
	// Only report an expected duration when the timings came from real
	// measurements rather than estimates.
	if useCircleCI || useJUnitXML {
		printMsg("expected test time: %0.1fs\n", bucketTimes[splitIndex])
	}
	fmt.Println(strings.Join(buckets[splitIndex], " "))
}