diff options
Diffstat (limited to 'main.go')
| -rw-r--r-- | main.go | 559 |
1 files changed, 559 insertions, 0 deletions
@@ -0,0 +1,559 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "math/rand" + "mime/multipart" + "net/http" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "time" + + "github.com/joho/godotenv" + "github.com/mattn/go-colorable" + "github.com/neilotoole/jsoncolor" + "gopkg.in/yaml.v3" +) + +const ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorGreen = "\033[32m" + colorYellow = "\033[33m" + colorBlue = "\033[34m" + colorPurple = "\033[35m" + colorCyan = "\033[36m" + colorWhite = "\033[37m" + colorBold = "\033[1m" +) + +// Config represents the Hepi configuration file structure. +type Config struct { + Environments yaml.Node `yaml:"environments"` + Requests yaml.Node `yaml:"requests"` + Groups map[string][]string `yaml:"groups"` +} + +// Request represents an individual API request definition. +type Request struct { + Method string `yaml:"method"` + URL string `yaml:"url"` + Description string `yaml:"description"` + Headers map[string]string `yaml:"headers"` + Params map[string]interface{} `yaml:"params"` + JSON map[string]interface{} `yaml:"json"` + Form map[string]interface{} `yaml:"form"` + Files map[string]string `yaml:"files"` +} + +// Runner manages the execution of API requests. +type Runner struct { + Config Config + EnvName string + Environment map[string]interface{} + Results map[string]interface{} + HTTPClient *http.Client + ShowHeaders bool +} + +const resultsFile = ".hepi.json" + +func main() { + godotenv.Load() + + envName := flag.String("env", "", "Environment to use") + filePath := flag.String("file", "", "Path to the YAML file") + reqNames := flag.String("req", "", "Comma-separated list of request names to execute") + groupName := flag.String("group", "", "Group to execute") + showHeaders := flag.Bool("headers", false, "Display response headers") + flag.Parse() + + if *filePath == "" || *envName == "" { + fmt.Printf("Error: -file and -env are required\n\n") + fmt.Printf("Usage: %s -env <environment> -file <file_path> [options]\n", os.Args[0]) + os.Exit(1) + } + + runner, err := NewRunner(*filePath, *envName) + if err != nil { + log.Fatalf("Error: %v", err) + } + runner.ShowHeaders = *showHeaders + + if *reqNames == "" && *groupName == "" { + runner.PrintHelp() + return + } + + if *groupName != "" { + if err := runner.ExecuteGroup(*groupName); err != nil { + log.Fatalf("Error: %v", err) + } + } + + if *reqNames != "" { + if err := runner.ExecuteRequests(*reqNames); err != nil { + log.Fatalf("Error: %v", err) + } + } +} + +// NewRunner initializes a new Hepi runner. +func NewRunner(filePath, envName string) (*Runner, error) { + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + var config Config + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + + if config.Environments.Kind != yaml.MappingNode { + return nil, fmt.Errorf("environments must be a mapping") + } + + selectedEnvName := envName + var selectedEnv map[string]interface{} + found := false + for i := 0; i < len(config.Environments.Content); i += 2 { + if config.Environments.Content[i].Value == envName { + if err := config.Environments.Content[i+1].Decode(&selectedEnv); err != nil { + return nil, fmt.Errorf("failed to decode environment %q: %w", envName, err) + } + found = true + break + } + } + if !found { + return nil, fmt.Errorf("environment %q not found", envName) + } + + return &Runner{ + Config: config, + EnvName: selectedEnvName, + Environment: selectedEnv, + Results: loadResults(selectedEnvName), + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + }, nil +} + +// ExecuteGroup runs all requests in the specified group. +func (r *Runner) ExecuteGroup(groupName string) error { + group, ok := r.Config.Groups[groupName] + if !ok { + return fmt.Errorf("group %q not found", groupName) + } + + for _, reqName := range group { + if err := r.ExecuteRequests(reqName); err != nil { + return err + } + } + + return nil +} + +// ExecuteRequests runs the specified requests. +func (r *Runner) ExecuteRequests(reqNames string) error { + filter := make(map[string]bool) + for _, name := range strings.Split(reqNames, ",") { + filter[strings.TrimSpace(name)] = true + } + + requestsNode := r.Config.Requests + if requestsNode.Kind != yaml.MappingNode { + return fmt.Errorf("requests must be a mapping") + } + + for i := 0; i < len(requestsNode.Content); i += 2 { + nameNode := requestsNode.Content[i] + valNode := requestsNode.Content[i+1] + + name := nameNode.Value + if !filter[name] { + continue + } + + var req Request + if err := valNode.Decode(&req); err != nil { + return fmt.Errorf("failed to decode request %q: %w", name, err) + } + + fmt.Printf("\n%s--- %s[%s]%s %s ---%s\n", colorBold, colorCyan, name, colorReset, req.Description, colorReset) + if err := r.executeRequest(name, req); err != nil { + log.Printf("Warning: request %q failed: %v", name, err) + } + } + + return nil +} + +func (r *Runner) executeRequest(name string, req Request) error { + rawURL := r.substitute(req.URL) + + // Handle query parameters + if req.Params != nil { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("failed to parse URL %q: %w", rawURL, err) + } + q := u.Query() + params := r.substituteMap(req.Params) + for k, v := range params { + q.Set(k, fmt.Sprintf("%v", v)) + } + u.RawQuery = q.Encode() + rawURL = u.String() + } + + methodColor := colorCyan + switch req.Method { + case "GET": + methodColor = colorGreen + case "POST": + methodColor = colorCyan + case "PUT", "PATCH": + methodColor = colorYellow + case "DELETE": + methodColor = colorRed + } + + fmt.Printf("%s%s%s %s\n", methodColor, req.Method, colorReset, rawURL) + + var bodyReader io.Reader + var contentType string + + if req.JSON != nil { + jsonBody := r.substituteMap(req.JSON) + data, _ := json.Marshal(jsonBody) + bodyReader = bytes.NewReader(data) + contentType = "application/json" + } else if req.Files != nil { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Add form fields + if req.Form != nil { + form := r.substituteMap(req.Form) + for k, v := range form { + _ = writer.WriteField(k, fmt.Sprintf("%v", v)) + } + } + + // Add files + for field, path := range req.Files { + substitutedPath := r.substitute(path) + file, err := os.Open(substitutedPath) + if err != nil { + return fmt.Errorf("failed to open file %q: %w", substitutedPath, err) + } + defer file.Close() + + part, err := writer.CreateFormFile(field, substitutedPath) + if err != nil { + return fmt.Errorf("failed to create form file for %q: %w", field, err) + } + _, _ = io.Copy(part, file) + } + + writer.Close() + bodyReader = body + contentType = writer.FormDataContentType() + } else if req.Form != nil { + formData := url.Values{} + form := r.substituteMap(req.Form) + for k, v := range form { + formData.Set(k, fmt.Sprintf("%v", v)) + } + bodyReader = strings.NewReader(formData.Encode()) + contentType = "application/x-www-form-urlencoded" + } + + httpReq, err := http.NewRequest(req.Method, rawURL, bodyReader) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + if contentType != "" { + httpReq.Header.Set("Content-Type", contentType) + } + + for k, v := range req.Headers { + httpReq.Header.Set(k, r.substitute(v)) + } + + startTime := time.Now() + resp, err := r.HTTPClient.Do(httpReq) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + duration := time.Since(startTime) + defer resp.Body.Close() + + statusColor := colorRed + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + statusColor = colorGreen + } else if resp.StatusCode >= 300 && resp.StatusCode < 400 { + statusColor = colorYellow + } + + fmt.Printf("Status: %s%s%s (took %s%v%s)\n", statusColor, resp.Status, colorReset, colorYellow, duration.Round(time.Millisecond), colorReset) + + if r.ShowHeaders { + fmt.Printf("\n%sHeaders:%s\n", colorBold, colorReset) + for k, v := range resp.Header { + fmt.Printf(" %s%s%s: %s\n", colorCyan, k, colorReset, strings.Join(v, ", ")) + } + } + + respData, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + if len(respData) > 0 { + var result interface{} + if err := json.Unmarshal(respData, &result); err == nil { + result = decodeRecursive(result) + r.Results[name] = result + saveResults(r.EnvName, r.Results) + fmt.Printf("\n%sResponse:%s\n", colorBold, colorReset) + + var enc *jsoncolor.Encoder + if jsoncolor.IsColorTerminal(os.Stdout) { + out := colorable.NewColorable(os.Stdout) + enc = jsoncolor.NewEncoder(out) + enc.SetColors(jsoncolor.DefaultColors()) + } else { + enc = jsoncolor.NewEncoder(os.Stdout) + } + + enc.SetIndent("", " ") + if err := enc.Encode(result); err != nil { + fmt.Println(result) + } + } else { + fmt.Printf("\n%sResponse (non-JSON):%s\n", colorBold, colorReset) + fmt.Println(string(respData)) + } + } + + return nil +} + +func (r *Runner) substitute(s string) string { + // 1. Handle [[dynamic]] placeholders using the Generators map and oneof support + genRegex := regexp.MustCompile(`\[\[(.*?)\]\]`) + s = genRegex.ReplaceAllStringFunc(s, func(match string) string { + tag := strings.Trim(match[2:len(match)-2], " ") + + // Handle [[oneof: a, b, c]] + if strings.HasPrefix(tag, "oneof:") { + parts := strings.Split(tag[6:], ",") + if len(parts) > 0 { + return strings.TrimSpace(parts[rand.Intn(len(parts))]) + } + } + + // Handle Generators map + if gen, ok := Generators[tag]; ok { + return gen() + } + + // Fallback for random_ prefix if not already present + if gen, ok := Generators["random_"+tag]; ok { + return gen() + } + + return match + }) + + // 2. Handle {{variables}} + re := regexp.MustCompile(`{{(.*?)}}`) + return re.ReplaceAllStringFunc(s, func(match string) string { + key := strings.Trim(match[2:len(match)-2], " ") + + // Priority 1: System Environment Variables + if val, exists := os.LookupEnv(key); exists { + return val + } + + // Priority 2: Config Environment Variables + if val, ok := r.Environment[key]; ok { + return fmt.Sprintf("%v", val) + } + + // Priority 3: Previous Request Results + parts := strings.Split(key, ".") + if len(parts) > 1 { + if res, ok := r.Results[parts[0]]; ok { + return getValueFromMap(res, parts[1:]) + } + } + + return match + }) +} + +func (r *Runner) substituteMap(m map[string]interface{}) map[string]interface{} { + res := make(map[string]interface{}) + for k, v := range m { + switch val := v.(type) { + case string: + res[k] = r.substitute(val) + case map[string]interface{}: + res[k] = r.substituteMap(val) + case []interface{}: + res[k] = r.substituteSlice(val) + default: + res[k] = v + } + } + return res +} + +func (r *Runner) substituteSlice(s []interface{}) []interface{} { + res := make([]interface{}, len(s)) + for i, v := range s { + switch val := v.(type) { + case string: + res[i] = r.substitute(val) + case map[string]interface{}: + res[i] = r.substituteMap(val) + case []interface{}: + res[i] = r.substituteSlice(val) + default: + res[i] = v + } + } + return res +} + +func (r *Runner) PrintHelp() { + fmt.Printf("Hepi - REST API Tester\n") + fmt.Printf("https://github.com/mitjafelicijan/hepi\n\n") + fmt.Println("Available Environments:") + if r.Config.Environments.Kind == yaml.MappingNode { + for i := 0; i < len(r.Config.Environments.Content); i += 2 { + name := r.Config.Environments.Content[i].Value + fmt.Printf(" - %s\n", name) + } + } + + fmt.Println("\nAvailable Requests:") + requestsNode := r.Config.Requests + if requestsNode.Kind == yaml.MappingNode { + maxLen := 0 + for i := 0; i < len(requestsNode.Content); i += 2 { + name := requestsNode.Content[i].Value + if len(name) > maxLen { + maxLen = len(name) + } + } + + for i := 0; i < len(requestsNode.Content); i += 2 { + name := requestsNode.Content[i].Value + valNode := requestsNode.Content[i+1] + var req Request + _ = valNode.Decode(&req) + fmt.Printf(" - %-*s %s\n", maxLen+2, name, req.Description) + } + } + + fmt.Println("\nAvailable Groups:") + for name, reqs := range r.Config.Groups { + fmt.Printf(" - %s (%s)\n", name, strings.Join(reqs, ", ")) + } + + fmt.Printf("\nUsage:\n %s -env <environment> -file <file_path> -req <request1,request2,...> -group <group_name> -headers\n", os.Args[0]) +} + +func loadResults(envName string) map[string]interface{} { + allResults := make(map[string]map[string]interface{}) + data, err := os.ReadFile(resultsFile) + if err != nil { + return make(map[string]interface{}) + } + json.Unmarshal(data, &allResults) + + if res, ok := allResults[envName]; ok { + return res + } + return make(map[string]interface{}) +} + +func saveResults(envName string, results map[string]interface{}) { + allResults := make(map[string]map[string]interface{}) + data, err := os.ReadFile(resultsFile) + if err == nil { + json.Unmarshal(data, &allResults) + } + + allResults[envName] = results + + output, err := json.MarshalIndent(allResults, "", " ") + if err != nil { + log.Printf("failed to marshal results: %v", err) + return + } + err = os.WriteFile(resultsFile, output, 0644) + if err != nil { + log.Printf("failed to save results: %v", err) + } +} + +func getValueFromMap(data interface{}, path []string) string { + for _, part := range path { + if m, ok := data.(map[string]interface{}); ok { + data = m[part] + } else if s, ok := data.([]interface{}); ok { + idx, err := strconv.Atoi(part) + if err == nil && idx >= 0 && idx < len(s) { + data = s[idx] + } else { + return fmt.Sprintf("{{MISSING:%s}}", part) + } + } else { + return fmt.Sprintf("{{MISSING:%s}}", part) + } + } + return fmt.Sprintf("%v", data) +} + +func decodeRecursive(data interface{}) interface{} { + switch v := data.(type) { + case map[string]interface{}: + for k, val := range v { + v[k] = decodeRecursive(val) + } + return v + case []interface{}: + for i, val := range v { + v[i] = decodeRecursive(val) + } + return v + case string: + // Try to decode if it looks like JSON (object or array) + trimmed := strings.TrimSpace(v) + if (strings.HasPrefix(trimmed, "{") && strings.HasSuffix(trimmed, "}")) || + (strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]")) { + var decoded interface{} + if err := json.Unmarshal([]byte(v), &decoded); err == nil { + return decodeRecursive(decoded) + } + } + return v + default: + return v + } +} |
