1package main
  2
  3import (
  4	"bytes"
  5	"encoding/json"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"log"
 10	"math/rand"
 11	"mime/multipart"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"regexp"
 16	"strconv"
 17	"strings"
 18	"time"
 19
 20	"github.com/joho/godotenv"
 21	"github.com/mattn/go-colorable"
 22	"github.com/neilotoole/jsoncolor"
 23	"gopkg.in/yaml.v3"
 24)
 25
 26const (
 27	colorReset  = "\033[0m"
 28	colorRed    = "\033[31m"
 29	colorGreen  = "\033[32m"
 30	colorYellow = "\033[33m"
 31	colorBlue   = "\033[34m"
 32	colorPurple = "\033[35m"
 33	colorCyan   = "\033[36m"
 34	colorWhite  = "\033[37m"
 35	colorBold   = "\033[1m"
 36)
 37
// Config represents the Hepi configuration file structure.
//
// Environments and Requests are decoded as raw yaml.Node values; the code
// walks their Content key/value pairs, which keeps entries in file order.
type Config struct {
	// Environments maps environment names to their variables; NewRunner
	// rejects the file unless this is a YAML mapping.
	Environments yaml.Node           `yaml:"environments"`
	// Requests maps request names to Request definitions, which are decoded
	// on demand when listed or executed.
	Requests     yaml.Node           `yaml:"requests"`
	// Groups maps a group name to the request names it runs, in list order.
	Groups       map[string][]string `yaml:"groups"`
}
 44
// Request represents an individual API request definition.
//
// String values in URL, Headers, Params, JSON, Form and Files may contain
// {{variable}} and [[generator]] placeholders, which are expanded before the
// request is sent. At most one body is sent, chosen by precedence: JSON, then
// multipart/form-data when Files is set (Form entries become extra fields),
// then URL-encoded Form.
type Request struct {
	// Method is the HTTP method sent as-is (net/http treats "" as GET).
	Method      string                 `yaml:"method"`
	// URL is the request URL; Params are merged into its query string.
	URL         string                 `yaml:"url"`
	// Description is shown in the help listing and before the request runs.
	Description string                 `yaml:"description"`
	// Headers are set on the outgoing request after Content-Type, so they
	// can override it.
	Headers     map[string]string      `yaml:"headers"`
	// Params are query parameters; each value is formatted with %v.
	Params      map[string]interface{} `yaml:"params"`
	// JSON, when non-nil, is marshaled as an application/json body.
	JSON        map[string]interface{} `yaml:"json"`
	// Form holds form fields, URL-encoded or multipart when Files is set.
	Form        map[string]interface{} `yaml:"form"`
	// Files maps a multipart field name to the path of the file to upload.
	Files       map[string]string      `yaml:"files"`
}
 56
// Runner manages the execution of API requests.
type Runner struct {
	// Config is the parsed request file.
	Config      Config
	// EnvName is the selected environment ("" if none was given); it also
	// keys this runner's entry in the state file.
	EnvName     string
	// Environment holds the selected environment's variables used for
	// {{name}} substitution; nil when no environment was selected.
	Environment map[string]interface{}
	// State maps request names to their decoded JSON responses, read by
	// {{request.path}} placeholders.
	State       map[string]interface{}
	// HTTPClient sends every request; NewRunner sets its Timeout.
	HTTPClient  *http.Client
	// ShowHeaders makes executeRequest print response headers.
	ShowHeaders bool
	// StateFile is the path of the JSON file State is loaded from and saved to.
	StateFile   string
}
 67
 68func main() {
 69	godotenv.Load()
 70
 71	var envName string
 72	flag.StringVar(&envName, "env", "", "Environment to use")
 73
 74	var filePath string
 75	flag.StringVar(&filePath, "file", "", "Path to the YAML file")
 76
 77	var statePath string
 78	flag.StringVar(&statePath, "state", ".hepi.json", "Path to state file")
 79	reqNames := flag.String("req", "", "Comma-separated list of request names to execute")
 80	groupName := flag.String("group", "", "Group to execute")
 81	showHeaders := flag.Bool("headers", false, "Display response headers")
 82	timeout := flag.Duration("timeout", 10*time.Second, "Request timeout duration")
 83	flag.Parse()
 84
 85	if filePath == "" {
 86		fmt.Printf("Error: -file is required\n\n")
 87		fmt.Printf("Usage: %s -env <environment> -file <file_path> [options]\n", os.Args[0])
 88		os.Exit(1)
 89	}
 90
 91	runner, err := NewRunner(filePath, envName, statePath, *timeout)
 92	if err != nil {
 93		log.Fatalf("Error: %v", err)
 94	}
 95	runner.ShowHeaders = *showHeaders
 96
 97	if *groupName == "" && *reqNames == "" {
 98		fmt.Printf("Error: -group or -req is required\n\n")
 99		runner.PrintHelp()
100		os.Exit(0)
101	}
102
103	if envName == "" {
104		runner.PrintHelp()
105		return
106	}
107
108	if *groupName != "" {
109		if err := runner.ExecuteGroup(*groupName); err != nil {
110			log.Fatalf("Error: %v", err)
111		}
112	}
113
114	if *reqNames != "" {
115		if err := runner.ExecuteRequests(*reqNames); err != nil {
116			log.Fatalf("Error: %v", err)
117		}
118	}
119}
120
121// NewRunner initializes a new Hepi runner.
122func NewRunner(filePath, envName, stateFile string, timeout time.Duration) (*Runner, error) {
123	data, err := os.ReadFile(filePath)
124	if err != nil {
125		return nil, fmt.Errorf("%sfailed to read file: %w%s", colorRed, err, colorReset)
126	}
127
128	var config Config
129	if err := yaml.Unmarshal(data, &config); err != nil {
130		return nil, fmt.Errorf("%sfailed to parse YAML: %w%s", colorRed, err, colorReset)
131	}
132
133	if config.Environments.Kind != yaml.MappingNode {
134		return nil, fmt.Errorf("%senvironments must be a mapping%s", colorRed, colorReset)
135	}
136
137	selectedEnvName := envName
138	var selectedEnv map[string]interface{}
139
140	if envName != "" {
141		found := false
142		var availableEnvs []string
143		for i := 0; i < len(config.Environments.Content); i += 2 {
144			name := config.Environments.Content[i].Value
145			availableEnvs = append(availableEnvs, name)
146			if name == envName {
147				if err := config.Environments.Content[i+1].Decode(&selectedEnv); err != nil {
148					return nil, fmt.Errorf("%sfailed to decode environment %q: %w%s", colorRed, envName, err, colorReset)
149				}
150				found = true
151				break
152			}
153		}
154		if !found {
155			return nil, fmt.Errorf("%senvironment %q not found\nAvailable environments:\n- %s%s", colorRed, envName, strings.Join(availableEnvs, "\n- "), colorReset)
156		}
157	}
158
159	return &Runner{
160		Config:      config,
161		EnvName:     selectedEnvName,
162		Environment: selectedEnv,
163		State:       loadState(selectedEnvName, stateFile),
164		StateFile:   stateFile,
165		HTTPClient:  &http.Client{Timeout: timeout},
166	}, nil
167}
168
169// ExecuteGroup runs all requests in the specified group.
170func (r *Runner) ExecuteGroup(groupName string) error {
171	group, ok := r.Config.Groups[groupName]
172	if !ok {
173		return fmt.Errorf("%sgroup %q not found%s", colorRed, groupName, colorReset)
174	}
175
176	for _, reqName := range group {
177		if err := r.ExecuteRequests(reqName); err != nil {
178			return err
179		}
180	}
181
182	return nil
183}
184
185// ExecuteRequests runs the specified requests.
186func (r *Runner) ExecuteRequests(reqNames string) error {
187	filter := make(map[string]bool)
188	for _, name := range strings.Split(reqNames, ",") {
189		filter[strings.TrimSpace(name)] = true
190	}
191
192	// Validate that all requested requests exist
193	foundRequests := make(map[string]bool)
194
195	requestsNode := r.Config.Requests
196	if requestsNode.Kind != yaml.MappingNode {
197		return fmt.Errorf("%srequests must be a mapping%s", colorRed, colorReset)
198	}
199
200	for i := 0; i < len(requestsNode.Content); i += 2 {
201		nameNode := requestsNode.Content[i]
202		valNode := requestsNode.Content[i+1]
203
204		name := nameNode.Value
205		if !filter[name] {
206			continue
207		}
208		foundRequests[name] = true
209
210		var req Request
211		if err := valNode.Decode(&req); err != nil {
212			if strings.Contains(err.Error(), "invalid map key") {
213				return fmt.Errorf("%sfailed to decode request %q: %w\n%sHint: Check for unquoted template variables like {{foo}} used as values%s", colorRed, name, err, colorYellow, colorReset)
214			}
215			return fmt.Errorf("%sfailed to decode request %q: %w%s", colorRed, name, err, colorReset)
216		}
217
218		fmt.Printf("\n%s--- %s[%s]%s %s ---%s\n", colorBold, colorCyan, name, colorReset, req.Description, colorReset)
219		if err := r.executeRequest(name, req); err != nil {
220			return err
221		}
222	}
223
224	var missing []string
225	for req := range filter {
226		if !foundRequests[req] {
227			missing = append(missing, req)
228		}
229	}
230	if len(missing) > 0 {
231		return fmt.Errorf("%srequests not found: %s%s", colorRed, strings.Join(missing, ", "), colorReset)
232	}
233
234	return nil
235}
236
237func (r *Runner) executeRequest(name string, req Request) error {
238	rawURL := r.substitute(req.URL)
239
240	// Handle query parameters
241	if req.Params != nil {
242		u, err := url.Parse(rawURL)
243		if err != nil {
244			return fmt.Errorf("%sfailed to parse URL %q: %w%s", colorRed, rawURL, err, colorReset)
245		}
246		q := u.Query()
247		params := r.substituteMap(req.Params)
248		for k, v := range params {
249			q.Set(k, fmt.Sprintf("%v", v))
250		}
251		u.RawQuery = q.Encode()
252		rawURL = u.String()
253	}
254
255	methodColor := colorCyan
256	switch req.Method {
257	case "GET":
258		methodColor = colorGreen
259	case "POST":
260		methodColor = colorCyan
261	case "PUT", "PATCH":
262		methodColor = colorYellow
263	case "DELETE":
264		methodColor = colorRed
265	}
266
267	fmt.Printf("%s%s%s %s\n", methodColor, req.Method, colorReset, rawURL)
268
269	var bodyReader io.Reader
270	var contentType string
271
272	if req.JSON != nil {
273		jsonBody := r.substituteMap(req.JSON)
274		data, _ := json.Marshal(jsonBody)
275		bodyReader = bytes.NewReader(data)
276		contentType = "application/json"
277	} else if req.Files != nil {
278		body := &bytes.Buffer{}
279		writer := multipart.NewWriter(body)
280
281		// Add form fields
282		if req.Form != nil {
283			form := r.substituteMap(req.Form)
284			for k, v := range form {
285				_ = writer.WriteField(k, fmt.Sprintf("%v", v))
286			}
287		}
288
289		// Add files
290		for field, path := range req.Files {
291			substitutedPath := r.substitute(path)
292			file, err := os.Open(substitutedPath)
293			if err != nil {
294				return fmt.Errorf("%sfailed to open file %q: %w%s", colorRed, substitutedPath, err, colorReset)
295			}
296			defer file.Close()
297
298			part, err := writer.CreateFormFile(field, substitutedPath)
299			if err != nil {
300				return fmt.Errorf("%sfailed to create form file for %q: %w%s", colorRed, field, err, colorReset)
301			}
302			_, _ = io.Copy(part, file)
303		}
304
305		writer.Close()
306		bodyReader = body
307		contentType = writer.FormDataContentType()
308	} else if req.Form != nil {
309		formData := url.Values{}
310		form := r.substituteMap(req.Form)
311		for k, v := range form {
312			formData.Set(k, fmt.Sprintf("%v", v))
313		}
314		bodyReader = strings.NewReader(formData.Encode())
315		contentType = "application/x-www-form-urlencoded"
316	}
317
318	httpReq, err := http.NewRequest(req.Method, rawURL, bodyReader)
319	if err != nil {
320		return fmt.Errorf("%sfailed to create HTTP request: %w%s", colorRed, err, colorReset)
321	}
322
323	if contentType != "" {
324		httpReq.Header.Set("Content-Type", contentType)
325	}
326
327	for k, v := range req.Headers {
328		httpReq.Header.Set(k, r.substitute(v))
329	}
330
331	startTime := time.Now()
332	resp, err := r.HTTPClient.Do(httpReq)
333	if err != nil {
334		if os.IsTimeout(err) {
335			return fmt.Errorf("%srequest timed out after %v%s", colorRed, r.HTTPClient.Timeout, colorReset)
336		}
337		return fmt.Errorf("%srequest failed: %w%s", colorRed, err, colorReset)
338	}
339	duration := time.Since(startTime)
340	defer resp.Body.Close()
341
342	statusColor := colorRed
343	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
344		statusColor = colorGreen
345	} else if resp.StatusCode >= 300 && resp.StatusCode < 400 {
346		statusColor = colorYellow
347	}
348
349	fmt.Printf("Status: %s%s%s (took %s%v%s)\n", statusColor, resp.Status, colorReset, colorYellow, duration.Round(time.Millisecond), colorReset)
350
351	if r.ShowHeaders {
352		fmt.Printf("\n%sHeaders:%s\n", colorBold, colorReset)
353		for k, v := range resp.Header {
354			fmt.Printf("  %s%s%s: %s\n", colorCyan, k, colorReset, strings.Join(v, ", "))
355		}
356	}
357
358	respData, err := io.ReadAll(resp.Body)
359	if err != nil {
360		return fmt.Errorf("%sfailed to read response body: %w%s", colorRed, err, colorReset)
361	}
362
363	if len(respData) > 0 {
364		var result interface{}
365		if err := json.Unmarshal(respData, &result); err == nil {
366			result = decodeRecursive(result)
367			r.State[name] = result
368			r.saveState()
369			fmt.Printf("\n%sResponse:%s\n", colorBold, colorReset)
370
371			var enc *jsoncolor.Encoder
372			if jsoncolor.IsColorTerminal(os.Stdout) {
373				out := colorable.NewColorable(os.Stdout)
374				enc = jsoncolor.NewEncoder(out)
375				enc.SetColors(jsoncolor.DefaultColors())
376			} else {
377				enc = jsoncolor.NewEncoder(os.Stdout)
378			}
379
380			enc.SetIndent("", "  ")
381			if err := enc.Encode(result); err != nil {
382				fmt.Println(result)
383			}
384		} else {
385			fmt.Printf("\n%sResponse (non-JSON):%s\n", colorBold, colorReset)
386			fmt.Println(string(respData))
387		}
388	}
389
390	return nil
391}
392
393func (r *Runner) substitute(s string) string {
394	// 1. Handle [[dynamic]] placeholders using the Generators map and oneof support
395	genRegex := regexp.MustCompile(`\[\[(.*?)\]\]`)
396	s = genRegex.ReplaceAllStringFunc(s, func(match string) string {
397		tag := strings.Trim(match[2:len(match)-2], " ")
398
399		// Handle [[oneof: a, b, c]]
400		if strings.HasPrefix(tag, "oneof:") {
401			parts := strings.Split(tag[6:], ",")
402			if len(parts) > 0 {
403				return strings.TrimSpace(parts[rand.Intn(len(parts))])
404			}
405		}
406
407		// Handle Generators map
408		if gen, ok := Generators[tag]; ok {
409			return gen()
410		}
411
412		// Fallback for random_ prefix if not already present
413		if gen, ok := Generators["random_"+tag]; ok {
414			return gen()
415		}
416
417		return match
418	})
419
420	// 2. Handle {{variables}}
421	re := regexp.MustCompile(`{{(.*?)}}`)
422	return re.ReplaceAllStringFunc(s, func(match string) string {
423		key := strings.Trim(match[2:len(match)-2], " ")
424
425		// Priority 1: System Environment Variables
426		if val, exists := os.LookupEnv(key); exists {
427			return val
428		}
429
430		// Priority 2: Config Environment Variables
431		if val, ok := r.Environment[key]; ok {
432			return fmt.Sprintf("%v", val)
433		}
434
435		// Priority 3: Previous Request Results
436		parts := strings.Split(key, ".")
437		if len(parts) > 1 {
438			if res, ok := r.State[parts[0]]; ok {
439				return getValueFromMap(res, parts[1:])
440			}
441		}
442
443		return match
444	})
445}
446
447func (r *Runner) substituteMap(m map[string]interface{}) map[string]interface{} {
448	res := make(map[string]interface{})
449	for k, v := range m {
450		switch val := v.(type) {
451		case string:
452			res[k] = r.substitute(val)
453		case map[string]interface{}:
454			res[k] = r.substituteMap(val)
455		case []interface{}:
456			res[k] = r.substituteSlice(val)
457		default:
458			res[k] = v
459		}
460	}
461	return res
462}
463
464func (r *Runner) substituteSlice(s []interface{}) []interface{} {
465	res := make([]interface{}, len(s))
466	for i, v := range s {
467		switch val := v.(type) {
468		case string:
469			res[i] = r.substitute(val)
470		case map[string]interface{}:
471			res[i] = r.substituteMap(val)
472		case []interface{}:
473			res[i] = r.substituteSlice(val)
474		default:
475			res[i] = v
476		}
477	}
478	return res
479}
480
481func (r *Runner) PrintHelp() {
482	fmt.Printf("Hepi - REST API Tester\n")
483	fmt.Printf("https://github.com/mitjafelicijan/hepi\n\n")
484	fmt.Println("Available Environments:")
485	if r.Config.Environments.Kind == yaml.MappingNode {
486		for i := 0; i < len(r.Config.Environments.Content); i += 2 {
487			name := r.Config.Environments.Content[i].Value
488			fmt.Printf("  - %s\n", name)
489		}
490	}
491
492	fmt.Println("\nAvailable Requests:")
493	requestsNode := r.Config.Requests
494	if requestsNode.Kind == yaml.MappingNode {
495		maxLen := 0
496		for i := 0; i < len(requestsNode.Content); i += 2 {
497			name := requestsNode.Content[i].Value
498			if len(name) > maxLen {
499				maxLen = len(name)
500			}
501		}
502
503		for i := 0; i < len(requestsNode.Content); i += 2 {
504			name := requestsNode.Content[i].Value
505			valNode := requestsNode.Content[i+1]
506			var req Request
507			_ = valNode.Decode(&req)
508			fmt.Printf("  - %-*s  %s\n", maxLen+2, name, req.Description)
509		}
510	}
511
512	fmt.Println("\nAvailable Groups:")
513	for name, reqs := range r.Config.Groups {
514		fmt.Printf("  - %s (%s)\n", name, strings.Join(reqs, ", "))
515	}
516
517	fmt.Printf("\nUsage:\n  %s -env <environment> -file <file_path> -req <request1,request2,...> -group <group_name> -headers\n", os.Args[0])
518}
519
// loadState returns the saved response state for envName from the JSON state
// file at stateFile. The file maps environment names to per-request results.
//
// A missing or empty file yields an empty state. Other read or parse problems
// are logged as warnings and fall back to whatever could be decoded, so a bad
// state file never stops requests from running. The returned map is never
// nil, so callers can store results into it directly.
func loadState(envName, stateFile string) map[string]interface{} {
	data, err := os.ReadFile(stateFile)
	if err != nil {
		if !os.IsNotExist(err) {
			log.Printf("warning: failed to read state file %q: %v", stateFile, err)
		}
		return make(map[string]interface{})
	}

	allStates := make(map[string]map[string]interface{})
	if len(data) > 0 {
		if err := json.Unmarshal(data, &allStates); err != nil {
			log.Printf("warning: failed to parse state file %q: %v", stateFile, err)
		}
	}

	// An entry stored as JSON null decodes to a nil map; hand back an empty
	// one instead so later writes into it don't panic. (Indexing a nil
	// allStates, from a top-level null, is safe and yields nil too.)
	if state := allStates[envName]; state != nil {
		return state
	}
	return make(map[string]interface{})
}
533
534func (r *Runner) saveState() {
535	allStates := make(map[string]map[string]interface{})
536	data, err := os.ReadFile(r.StateFile)
537	if err == nil {
538		json.Unmarshal(data, &allStates)
539	}
540
541	allStates[r.EnvName] = r.State
542
543	output, err := json.MarshalIndent(allStates, "", "  ")
544	if err != nil {
545		log.Printf("failed to marshal state: %v", err)
546		return
547	}
548	err = os.WriteFile(r.StateFile, output, 0644)
549	if err != nil {
550		log.Printf("failed to save state: %v", err)
551	}
552}
553
// getValueFromMap walks path through data and renders the value it finds as a
// string for template substitution. Path segments are map keys for JSON
// objects and decimal indexes for JSON arrays.
//
// Rendering rules:
//   - numbers are written in plain decimal, without an exponent;
//   - objects and arrays are written as compact JSON;
//   - everything else uses fmt's %v (so JSON null becomes "<nil>").
//
// If a key is absent, an index is invalid or out of range, or a segment tries
// to descend into a scalar, "{{MISSING:<segment>}}" is returned, naming the
// offending segment.
func getValueFromMap(data interface{}, path []string) string {
	for _, part := range path {
		switch node := data.(type) {
		case map[string]interface{}:
			val, ok := node[part]
			if !ok {
				return fmt.Sprintf("{{MISSING:%s}}", part)
			}
			data = val
		case []interface{}:
			idx, err := strconv.Atoi(part)
			if err != nil || idx < 0 || idx >= len(node) {
				return fmt.Sprintf("{{MISSING:%s}}", part)
			}
			data = node[idx]
		default:
			return fmt.Sprintf("{{MISSING:%s}}", part)
		}
	}

	switch v := data.(type) {
	case float64:
		// JSON numbers decode as float64, and %v would print 1234567 as
		// "1.234567e+06", which is unusable as an ID in a URL or header.
		return strconv.FormatFloat(v, 'f', -1, 64)
	case map[string]interface{}, []interface{}:
		// %v would produce Go syntax such as "map[id:1]"; JSON is what an API
		// expects when a whole object or array is substituted.
		if encoded, err := json.Marshal(v); err == nil {
			return string(encoded)
		}
	}
	return fmt.Sprintf("%v", data)
}
571
// decodeRecursive walks a value produced by encoding/json and replaces any
// string whose trimmed form looks like a JSON object ({...}) or array ([...])
// with its decoded form, recursing into the result. This unwraps JSON that an
// API returned double-encoded inside a string. Maps and slices are updated in
// place and returned; strings that fail to parse, and all other values, are
// returned unchanged.
func decodeRecursive(data interface{}) interface{} {
	if obj, ok := data.(map[string]interface{}); ok {
		for key := range obj {
			obj[key] = decodeRecursive(obj[key])
		}
		return obj
	}

	if arr, ok := data.([]interface{}); ok {
		for idx := range arr {
			arr[idx] = decodeRecursive(arr[idx])
		}
		return arr
	}

	str, ok := data.(string)
	if !ok {
		return data
	}

	t := strings.TrimSpace(str)
	isObject := strings.HasPrefix(t, "{") && strings.HasSuffix(t, "}")
	isArray := strings.HasPrefix(t, "[") && strings.HasSuffix(t, "]")
	if !isObject && !isArray {
		return str
	}

	var nested interface{}
	if err := json.Unmarshal([]byte(str), &nested); err != nil {
		return str
	}
	return decodeRecursive(nested)
}