1package util
  2
  3import (
  4	"errors"
  5	"io"
  6	"os"
  7	"path/filepath"
  8	"strconv"
  9	"sync"
 10	"time"
 11
 12	"github.com/go-git/go-billy/v5"
 13)
 14
 15// RemoveAll removes path and any children it contains. It removes everything it
 16// can but returns the first error it encounters. If the path does not exist,
 17// RemoveAll returns nil (no error).
 18func RemoveAll(fs billy.Basic, path string) error {
 19	if r, ok := fs.(removerAll); ok {
 20		return r.RemoveAll(path)
 21	}
 22
 23	return removeAll(fs, path)
 24}
 25
 26type removerAll interface {
 27	RemoveAll(string) error
 28}
 29
 30func removeAll(fs billy.Basic, path string) error {
 31	// This implementation is adapted from os.RemoveAll.
 32
 33	// Simple case: if Remove works, we're done.
 34	err := fs.Remove(path)
 35	if err == nil || errors.Is(err, os.ErrNotExist) {
 36		return nil
 37	}
 38
 39	// Otherwise, is this a directory we need to recurse into?
 40	dir, serr := lstat(fs, path)
 41	if serr != nil {
 42		if errors.Is(serr, os.ErrNotExist) {
 43			return nil
 44		}
 45
 46		return serr
 47	}
 48
 49	if dir.Mode()&os.ModeSymlink != 0 || !dir.IsDir() {
 50		// Not a directory we should recurse into; return the error from Remove.
 51		return err
 52	}
 53
 54	dirfs, ok := fs.(billy.Dir)
 55	if !ok {
 56		return billy.ErrNotSupported
 57	}
 58
 59	// Directory.
 60	fis, err := dirfs.ReadDir(path)
 61	if err != nil {
 62		if errors.Is(err, os.ErrNotExist) {
 63			// Race. It was deleted between the Lstat and ReadDir.
 64			// Return nil per RemoveAll's docs.
 65			return nil
 66		}
 67
 68		return err
 69	}
 70
 71	// Remove contents & return first error.
 72	err = nil
 73	for _, fi := range fis {
 74		cpath := fs.Join(path, fi.Name())
 75		err1 := removeAll(fs, cpath)
 76		if err == nil {
 77			err = err1
 78		}
 79	}
 80
 81	// Remove directory.
 82	err1 := fs.Remove(path)
 83	if err1 == nil || errors.Is(err1, os.ErrNotExist) {
 84		return nil
 85	}
 86
 87	if err == nil {
 88		err = err1
 89	}
 90
 91	return err
 92}
 93
 94func lstat(filesystem billy.Basic, path string) (os.FileInfo, error) {
 95	if sl, ok := filesystem.(billy.Symlink); ok {
 96		// Avoid following a symlink substituted after the initial Remove fails.
 97		fi, err := sl.Lstat(path)
 98		if err == nil || !errors.Is(err, billy.ErrNotSupported) {
 99			return fi, err
100		}
101	}
102
103	return filesystem.Stat(path)
104}
105
106// WriteFile writes data to a file named by filename in the given filesystem.
107// If the file does not exist, WriteFile creates it with permissions perm;
108// otherwise WriteFile truncates it before writing.
109func WriteFile(fs billy.Basic, filename string, data []byte, perm os.FileMode) (err error) {
110	f, err := fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
111	if err != nil {
112		return err
113	}
114	defer func() {
115		if f != nil {
116			err1 := f.Close()
117			if err == nil {
118				err = err1
119			}
120		}
121	}()
122
123	n, err := f.Write(data)
124	if err == nil && n < len(data) {
125		err = io.ErrShortWrite
126	}
127
128	return nil
129}
130
// State of the temporary-name generator. Random-looking names make it likely
// that a candidate does not exist yet, which keeps the number of attempts in
// TempFile and TempDir low. rand is the current generator state (0 means
// "unseeded"); randmu guards every read and write of it.
var (
	rand   uint32
	randmu sync.Mutex
)

// reseed derives a fresh generator state from the wall clock and the process ID.
func reseed() uint32 {
	seed := time.Now().UnixNano() + int64(os.Getpid())
	return uint32(seed)
}

// nextSuffix advances the generator by one linear-congruential step (constants
// from Numerical Recipes), seeding it first if needed. It returns the new state
// modulo 1e9 as a zero-padded, nine-digit decimal string.
func nextSuffix() string {
	state := func() uint32 {
		randmu.Lock()
		defer randmu.Unlock()

		next := rand
		if next == 0 {
			next = reseed()
		}
		next = next*1664525 + 1013904223
		rand = next
		return next
	}()

	// Adding 1e9 and dropping the leading '1' pads to exactly nine digits.
	return strconv.FormatUint(uint64(state%1e9)+1e9, 10)[1:]
}
155
156// TempFile creates a new temporary file in the directory dir with a name
157// beginning with prefix, opens the file for reading and writing, and returns
158// the resulting *os.File. If dir is the empty string, TempFile uses the default
159// directory for temporary files (see os.TempDir). Multiple programs calling
160// TempFile simultaneously will not choose the same file. The caller can use
161// f.Name() to find the pathname of the file. It is the caller's responsibility
162// to remove the file when no longer needed.
163func TempFile(fs billy.Basic, dir, prefix string) (f billy.File, err error) {
164	// This implementation is based on stdlib ioutil.TempFile.
165	if dir == "" {
166		dir = getTempDir(fs)
167	}
168
169	nconflict := 0
170	for i := 0; i < 10000; i++ {
171		name := filepath.Join(dir, prefix+nextSuffix())
172		f, err = fs.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
173		if errors.Is(err, os.ErrExist) {
174			if nconflict++; nconflict > 10 {
175				randmu.Lock()
176				rand = reseed()
177				randmu.Unlock()
178			}
179			continue
180		}
181		break
182	}
183	return
184}
185
186// TempDir creates a new temporary directory in the directory dir
187// with a name beginning with prefix and returns the path of the
188// new directory. If dir is the empty string, TempDir uses the
189// default directory for temporary files (see os.TempDir).
190// Multiple programs calling TempDir simultaneously
191// will not choose the same directory. It is the caller's responsibility
192// to remove the directory when no longer needed.
193func TempDir(fs billy.Dir, dir, prefix string) (name string, err error) {
194	// This implementation is based on stdlib ioutil.TempDir
195
196	if dir == "" {
197		dir = getTempDir(fs.(billy.Basic))
198	}
199
200	nconflict := 0
201	for i := 0; i < 10000; i++ {
202		try := filepath.Join(dir, prefix+nextSuffix())
203		err = fs.MkdirAll(try, 0700)
204		if errors.Is(err, os.ErrExist) {
205			if nconflict++; nconflict > 10 {
206				randmu.Lock()
207				rand = reseed()
208				randmu.Unlock()
209			}
210			continue
211		}
212		if errors.Is(err, os.ErrNotExist) {
213			if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) {
214				return "", err
215			}
216		}
217		if err == nil {
218			name = try
219		}
220		break
221	}
222	return
223}
224
225func getTempDir(fs billy.Basic) string {
226	ch, ok := fs.(billy.Chroot)
227	if !ok || ch.Root() == "" || ch.Root() == "/" || ch.Root() == string(filepath.Separator) {
228		return os.TempDir()
229	}
230
231	return ".tmp"
232}
233
234// ReadFile reads the named file and returns the contents from the given filesystem.
235// A successful call returns err == nil, not err == EOF.
236// Because ReadFile reads the whole file, it does not treat an EOF from Read
237// as an error to be reported.
238func ReadFile(fs billy.Basic, name string) ([]byte, error) {
239	f, err := fs.Open(name)
240	if err != nil {
241		return nil, err
242	}
243
244	defer f.Close()
245
246	var size int
247	if info, err := fs.Stat(name); err == nil {
248		size64 := info.Size()
249		if int64(int(size64)) == size64 {
250			size = int(size64)
251		}
252	}
253
254	size++ // one byte for final read at EOF
255	// If a file claims a small size, read at least 512 bytes.
256	// In particular, files in Linux's /proc claim size 0 but
257	// then do not work right if read in small pieces,
258	// so an initial read of 1 byte would not work correctly.
259
260	if size < 512 {
261		size = 512
262	}
263
264	data := make([]byte, 0, size)
265	for {
266		if len(data) >= cap(data) {
267			d := append(data[:cap(data)], 0)
268			data = d[:len(data)]
269		}
270
271		n, err := f.Read(data[len(data):cap(data)])
272		data = data[:len(data)+n]
273
274		if err != nil {
275			if errors.Is(err, io.EOF) {
276				err = nil
277			}
278
279			return data, err
280		}
281	}
282}