1// Copyright 2011 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"context"
  9	"errors"
 10	"fmt"
 11	"io"
 12	"math/rand"
 13	"net"
 14	"net/netip"
 15	"strconv"
 16	"strings"
 17	"sync"
 18	"time"
 19)
 20
 21// Listen requests the remote peer open a listening socket on
 22// addr. Incoming connections will be available by calling Accept on
 23// the returned net.Listener. The listener must be serviced, or the
 24// SSH connection may hang.
 25// N must be "tcp", "tcp4", "tcp6", or "unix".
 26//
 27// If the address is a hostname, it is sent to the remote peer as-is, without
 28// being resolved locally, and the Listener Addr method will return a zero IP.
 29func (c *Client) Listen(n, addr string) (net.Listener, error) {
 30	switch n {
 31	case "tcp", "tcp4", "tcp6":
 32		host, portStr, err := net.SplitHostPort(addr)
 33		if err != nil {
 34			return nil, err
 35		}
 36		port, err := strconv.ParseInt(portStr, 10, 32)
 37		if err != nil {
 38			return nil, err
 39		}
 40		return c.listenTCPInternal(host, int(port))
 41	case "unix":
 42		return c.ListenUnix(addr)
 43	default:
 44		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
 45	}
 46}
 47
 48// Automatic port allocation is broken with OpenSSH before 6.0. See
 49// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
 50// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
 51// rather than the actual port number. This means you can never open
 52// two different listeners with auto allocated ports. We work around
 53// this by trying explicit ports until we succeed.
 54
 55const openSSHPrefix = "OpenSSH_"
 56
 57var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
 58
 59// isBrokenOpenSSHVersion returns true if the given version string
 60// specifies a version of OpenSSH that is known to have a bug in port
 61// forwarding.
 62func isBrokenOpenSSHVersion(versionStr string) bool {
 63	i := strings.Index(versionStr, openSSHPrefix)
 64	if i < 0 {
 65		return false
 66	}
 67	i += len(openSSHPrefix)
 68	j := i
 69	for ; j < len(versionStr); j++ {
 70		if versionStr[j] < '0' || versionStr[j] > '9' {
 71			break
 72		}
 73	}
 74	version, _ := strconv.Atoi(versionStr[i:j])
 75	return version < 6
 76}
 77
 78// autoPortListenWorkaround simulates automatic port allocation by
 79// trying random ports repeatedly.
 80func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
 81	var sshListener net.Listener
 82	var err error
 83	const tries = 10
 84	for i := 0; i < tries; i++ {
 85		addr := *laddr
 86		addr.Port = 1024 + portRandomizer.Intn(60000)
 87		sshListener, err = c.ListenTCP(&addr)
 88		if err == nil {
 89			laddr.Port = addr.Port
 90			return sshListener, err
 91		}
 92	}
 93	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
 94}
 95
// RFC 4254 7.1
//
// channelForwardMsg is the payload of the "tcpip-forward" and
// "cancel-tcpip-forward" global requests: the address to bind on the
// server, followed by the port number. It is serialized with Marshal.
// Presumably the fields are encoded in declaration order, so do not reorder
// them (verify against Marshal).
type channelForwardMsg struct {
	addr  string
	rport uint32
}
101
102// handleForwards starts goroutines handling forwarded connections.
103// It's called on first use by (*Client).ListenTCP to not launch
104// goroutines until needed.
105func (c *Client) handleForwards() {
106	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
107	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
108}
109
110// ListenTCP requests the remote peer open a listening socket
111// on laddr. Incoming connections will be available by calling
112// Accept on the returned net.Listener.
113//
114// ListenTCP accepts an IP address, to provide a hostname use [Client.Listen]
115// with "tcp", "tcp4", or "tcp6" network instead.
116func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
117	c.handleForwardsOnce.Do(c.handleForwards)
118	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
119		return c.autoPortListenWorkaround(laddr)
120	}
121
122	return c.listenTCPInternal(laddr.IP.String(), laddr.Port)
123}
124
125func (c *Client) listenTCPInternal(host string, port int) (net.Listener, error) {
126	c.handleForwardsOnce.Do(c.handleForwards)
127
128	m := channelForwardMsg{
129		host,
130		uint32(port),
131	}
132	// send message
133	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
134	if err != nil {
135		return nil, err
136	}
137	if !ok {
138		return nil, errors.New("ssh: tcpip-forward request denied by peer")
139	}
140
141	// If the original port was 0, then the remote side will
142	// supply a real port number in the response.
143	if port == 0 {
144		var p struct {
145			Port uint32
146		}
147		if err := Unmarshal(resp, &p); err != nil {
148			return nil, err
149		}
150		port = int(p.Port)
151	}
152	// Construct a local address placeholder for the remote listener. If the
153	// original host is an IP address, preserve it so that Listener.Addr()
154	// reports the same IP. If the host is a hostname or cannot be parsed as an
155	// IP, fall back to IPv4zero. The port field is always set, even if the
156	// original port was 0, because in that case the remote server will assign
157	// one, allowing callers to determine which port was selected.
158	ip := net.IPv4zero
159	if parsed, err := netip.ParseAddr(host); err == nil {
160		ip = net.IP(parsed.AsSlice())
161	}
162	laddr := &net.TCPAddr{
163		IP:   ip,
164		Port: port,
165	}
166	addr := net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
167	ch := c.forwards.add("tcp", addr)
168
169	return &tcpListener{laddr, addr, c, ch}, nil
170}
171
// forwardList stores a mapping between remote
// forward requests and the tcpListeners.
//
// The embedded Mutex guards entries. Entries are matched by (network, addr),
// where addr is the address exactly as specified in the original forward
// request: "host:port" for TCP, or the socket path for unix forwards. This
// is the form expected by add, remove and forward.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry
}
178
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
	addr    string       // host:port or socket path, as sent in the forward request
	network string       // tcp or unix
	c       chan forward // created by add with a buffer of 1; closed by remove/closeAll
}
186
// forward represents an incoming forwarded connection (TCP or unix) that
// has not yet been accepted. It is delivered to a listener through the
// channel of the matching forwardList entry. (forwardList entries are
// matched by the address as specified in the original forward request.)
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
194
195func (l *forwardList) add(n, addr string) chan forward {
196	l.Lock()
197	defer l.Unlock()
198	f := forwardEntry{
199		addr:    addr,
200		network: n,
201		c:       make(chan forward, 1),
202	}
203	l.entries = append(l.entries, f)
204	return f.c
205}
206
// See RFC 4254, section 7.2
//
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel
// open: the address and port that were connected (as given in the original
// tcpip-forward request), then the originator's IP address and port.
// It is decoded with Unmarshal.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
214
// parseTCPAddr converts the originator address and port reported by the
// remote side into a *net.TCPAddr.
//
// addr must be an IP literal (hostnames are not resolved) and port must be
// in [1, 65535]; otherwise an error is returned.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	const maxPort = 65535
	if port < 1 || port > maxPort {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}

	parsed, parseErr := netip.ParseAddr(addr)
	if parseErr != nil {
		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
	}

	tcpAddr := &net.TCPAddr{Port: int(port)}
	tcpAddr.IP = net.IP(parsed.AsSlice())
	return tcpAddr, nil
}
226
227func (l *forwardList) handleChannels(in <-chan NewChannel) {
228	for ch := range in {
229		var (
230			addr    string
231			network string
232			raddr   net.Addr
233			err     error
234		)
235		switch channelType := ch.ChannelType(); channelType {
236		case "forwarded-tcpip":
237			var payload forwardedTCPPayload
238			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
239				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
240				continue
241			}
242
243			// RFC 4254 section 7.2 specifies that incoming addresses should
244			// list the address that was connected, in string format. It is the
245			// same address used in the tcpip-forward request. The originator
246			// address is an IP address instead.
247			addr = net.JoinHostPort(payload.Addr, strconv.FormatUint(uint64(payload.Port), 10))
248
249			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
250			if err != nil {
251				ch.Reject(ConnectionFailed, err.Error())
252				continue
253			}
254			network = "tcp"
255		case "forwarded-streamlocal@openssh.com":
256			var payload forwardedStreamLocalPayload
257			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
258				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
259				continue
260			}
261			addr = payload.SocketPath
262			raddr = &net.UnixAddr{
263				Name: "@",
264				Net:  "unix",
265			}
266			network = "unix"
267		default:
268			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
269		}
270		if ok := l.forward(network, addr, raddr, ch); !ok {
271			// Section 7.2, implementations MUST reject spurious incoming
272			// connections.
273			ch.Reject(Prohibited, "no forward for address")
274			continue
275		}
276
277	}
278}
279
280// remove removes the forward entry, and the channel feeding its
281// listener.
282func (l *forwardList) remove(n, addr string) {
283	l.Lock()
284	defer l.Unlock()
285	for i, f := range l.entries {
286		if n == f.network && addr == f.addr {
287			l.entries = append(l.entries[:i], l.entries[i+1:]...)
288			close(f.c)
289			return
290		}
291	}
292}
293
294// closeAll closes and clears all forwards.
295func (l *forwardList) closeAll() {
296	l.Lock()
297	defer l.Unlock()
298	for _, f := range l.entries {
299		close(f.c)
300	}
301	l.entries = nil
302}
303
304func (l *forwardList) forward(n, addr string, raddr net.Addr, ch NewChannel) bool {
305	l.Lock()
306	defer l.Unlock()
307	for _, f := range l.entries {
308		if n == f.network && addr == f.addr {
309			f.c <- forward{newCh: ch, raddr: raddr}
310			return true
311		}
312	}
313	return false
314}
315
// tcpListener is the net.Listener returned for a remote TCP forward.
// Connections forwarded by the server arrive on in and are accepted as
// chanConns.
type tcpListener struct {
	laddr *net.TCPAddr // returned by Addr and used as LocalAddr of accepted conns
	addr  string       // host:port as sent in tcpip-forward; key into forwardList

	conn *Client        // client that owns the forward
	in   <-chan forward // closed when the forward is removed from forwardList
}
323
324// Accept waits for and returns the next connection to the listener.
325func (l *tcpListener) Accept() (net.Conn, error) {
326	s, ok := <-l.in
327	if !ok {
328		return nil, io.EOF
329	}
330	ch, incoming, err := s.newCh.Accept()
331	if err != nil {
332		return nil, err
333	}
334	go DiscardRequests(incoming)
335
336	return &chanConn{
337		Channel: ch,
338		laddr:   l.laddr,
339		raddr:   s.raddr,
340	}, nil
341}
342
343// Close closes the listener.
344func (l *tcpListener) Close() error {
345	host, port, err := net.SplitHostPort(l.addr)
346	if err != nil {
347		return err
348	}
349	rport, err := strconv.ParseUint(port, 10, 32)
350	if err != nil {
351		return err
352	}
353	m := channelForwardMsg{
354		host,
355		uint32(rport),
356	}
357
358	// this also closes the listener.
359	l.conn.forwards.remove("tcp", l.addr)
360	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
361	if err == nil && !ok {
362		err = errors.New("ssh: cancel-tcpip-forward failed")
363	}
364	return err
365}
366
367// Addr returns the listener's network address.
368func (l *tcpListener) Addr() net.Addr {
369	return l.laddr
370}
371
372// DialContext initiates a connection to the addr from the remote host.
373//
374// The provided Context must be non-nil. If the context expires before the
375// connection is complete, an error is returned. Once successfully connected,
376// any expiration of the context will not affect the connection.
377//
378// See func Dial for additional information.
379func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
380	if err := ctx.Err(); err != nil {
381		return nil, err
382	}
383	type connErr struct {
384		conn net.Conn
385		err  error
386	}
387	ch := make(chan connErr)
388	go func() {
389		conn, err := c.Dial(n, addr)
390		select {
391		case ch <- connErr{conn, err}:
392		case <-ctx.Done():
393			if conn != nil {
394				conn.Close()
395			}
396		}
397	}()
398	select {
399	case res := <-ch:
400		return res.conn, res.err
401	case <-ctx.Done():
402		return nil, ctx.Err()
403	}
404}
405
406// Dial initiates a connection to the addr from the remote host.
407// The resulting connection has a zero LocalAddr() and RemoteAddr().
408func (c *Client) Dial(n, addr string) (net.Conn, error) {
409	var ch Channel
410	switch n {
411	case "tcp", "tcp4", "tcp6":
412		// Parse the address into host and numeric port.
413		host, portString, err := net.SplitHostPort(addr)
414		if err != nil {
415			return nil, err
416		}
417		port, err := strconv.ParseUint(portString, 10, 16)
418		if err != nil {
419			return nil, err
420		}
421		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
422		if err != nil {
423			return nil, err
424		}
425		// Use a zero address for local and remote address.
426		zeroAddr := &net.TCPAddr{
427			IP:   net.IPv4zero,
428			Port: 0,
429		}
430		return &chanConn{
431			Channel: ch,
432			laddr:   zeroAddr,
433			raddr:   zeroAddr,
434		}, nil
435	case "unix":
436		var err error
437		ch, err = c.dialStreamLocal(addr)
438		if err != nil {
439			return nil, err
440		}
441		return &chanConn{
442			Channel: ch,
443			laddr: &net.UnixAddr{
444				Name: "@",
445				Net:  "unix",
446			},
447			raddr: &net.UnixAddr{
448				Name: addr,
449				Net:  "unix",
450			},
451		}, nil
452	default:
453		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
454	}
455}
456
457// DialTCP connects to the remote address raddr on the network net,
458// which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
459// as the local address for the connection.
460func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
461	if laddr == nil {
462		laddr = &net.TCPAddr{
463			IP:   net.IPv4zero,
464			Port: 0,
465		}
466	}
467	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
468	if err != nil {
469		return nil, err
470	}
471	return &chanConn{
472		Channel: ch,
473		laddr:   laddr,
474		raddr:   raddr,
475	}, nil
476}
477
// RFC 4254 7.2
//
// channelOpenDirectMsg is the extra data of a "direct-tcpip" channel open:
// the host and port to connect to on the remote side, then the
// originator's address and port. It is serialized with Marshal; keep the
// field order.
type channelOpenDirectMsg struct {
	raddr string
	rport uint32
	laddr string
	lport uint32
}
485
486func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
487	msg := channelOpenDirectMsg{
488		raddr: raddr,
489		rport: uint32(rport),
490		laddr: laddr,
491		lport: uint32(lport),
492	}
493	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
494	if err != nil {
495		return nil, err
496	}
497	go DiscardRequests(in)
498	return ch, nil
499}
500
// tcpChan wraps a Channel. It is not referenced in this part of the file;
// it may be kept for compatibility (the deadline errors below still mention
// "tcpChan"). Verify there are no other users before removing it.
type tcpChan struct {
	Channel // the backing channel
}
504
// chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
//
// It embeds a Channel for reading, writing and closing, and adds fixed
// local and remote addresses. Deadlines are not supported.
type chanConn struct {
	Channel
	laddr, raddr net.Addr
}
511
512// LocalAddr returns the local network address.
513func (t *chanConn) LocalAddr() net.Addr {
514	return t.laddr
515}
516
517// RemoteAddr returns the remote network address.
518func (t *chanConn) RemoteAddr() net.Addr {
519	return t.raddr
520}
521
522// SetDeadline sets the read and write deadlines associated
523// with the connection.
524func (t *chanConn) SetDeadline(deadline time.Time) error {
525	if err := t.SetReadDeadline(deadline); err != nil {
526		return err
527	}
528	return t.SetWriteDeadline(deadline)
529}
530
531// SetReadDeadline sets the read deadline.
532// A zero value for t means Read will not time out.
533// After the deadline, the error from Read will implement net.Error
534// with Timeout() == true.
535func (t *chanConn) SetReadDeadline(deadline time.Time) error {
536	// for compatibility with previous version,
537	// the error message contains "tcpChan"
538	return errors.New("ssh: tcpChan: deadline not supported")
539}
540
541// SetWriteDeadline exists to satisfy the net.Conn interface
542// but is not implemented by this type.  It always returns an error.
543func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
544	return errors.New("ssh: tcpChan: deadline not supported")
545}