1// Copyright 2011 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"bufio"
  9	"bytes"
 10	"errors"
 11	"fmt"
 12	"io"
 13	"log"
 14)
 15
// debugTransport, when set to true, makes the transport log the message
// type byte (the first byte) of every packet it reads or writes, via
// printPacket. No message decoding is done, to minimize the impact on
// timing.
const debugTransport = false
 19
 20// packetConn represents a transport that implements packet based
 21// operations.
 22type packetConn interface {
 23	// Encrypt and send a packet of data to the remote peer.
 24	writePacket(packet []byte) error
 25
 26	// Read a packet from the connection. The read is blocking,
 27	// i.e. if error is nil, then the returned byte slice is
 28	// always non-empty.
 29	readPacket() ([]byte, error)
 30
 31	// Close closes the write-side of the connection.
 32	Close() error
 33}
 34
// transport is the keyingTransport that implements the SSH packet
// protocol. Each direction has its own connectionState (cipher, sequence
// number, pending key change); bytes flow through bufReader and bufWriter,
// which newTransport wraps around the underlying connection.
type transport struct {
	reader connectionState
	writer connectionState

	// rand is handed to the write-side cipher for every packet written.
	// isClient only labels debug output (printPacket). The embedded Closer
	// is the underlying connection.
	bufReader *bufio.Reader
	bufWriter *bufio.Writer
	rand      io.Reader
	isClient  bool
	io.Closer

	// strictMode is set by setStrictMode and initialKEXDone by
	// setInitialKEXDone. While strictMode is on and initialKEXDone is off,
	// readPacket returns IGNORE/DEBUG messages instead of dropping them.
	// In strict mode the sequence numbers also restart at 0 after each
	// msgNewKeys packet.
	strictMode     bool
	initialKEXDone bool
}
 50
 51// packetCipher represents a combination of SSH encryption/MAC
 52// protocol.  A single instance should be used for one direction only.
 53type packetCipher interface {
 54	// writeCipherPacket encrypts the packet and writes it to w. The
 55	// contents of the packet are generally scrambled.
 56	writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
 57
 58	// readCipherPacket reads and decrypts a packet of data. The
 59	// returned packet may be overwritten by future calls of
 60	// readPacket.
 61	readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
 62}
 63
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
//
// The embedded packetCipher is the cipher currently in use. seqNum is the
// sequence number passed to the cipher for the next packet. It is
// incremented after every packet read and after every packet successfully
// written. In strict mode it is reset to 0 when a msgNewKeys packet
// switches ciphers. dir holds the key-derivation tags for this direction.
// pendingKeyChange carries the next cipher, queued by prepareKeyChange,
// until the msgNewKeys packet is read or written.
type connectionState struct {
	packetCipher
	seqNum           uint32
	dir              direction
	pendingKeyChange chan packetCipher
}
 73
 74func (t *transport) setStrictMode() error {
 75	if t.reader.seqNum != 1 {
 76		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
 77	}
 78	t.strictMode = true
 79	return nil
 80}
 81
 82func (t *transport) setInitialKEXDone() {
 83	t.initialKEXDone = true
 84}
 85
 86// prepareKeyChange sets up key material for a keychange. The key changes in
 87// both directions are triggered by reading and writing a msgNewKey packet
 88// respectively.
 89func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error {
 90	ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult)
 91	if err != nil {
 92		return err
 93	}
 94	t.reader.pendingKeyChange <- ciph
 95
 96	ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult)
 97	if err != nil {
 98		return err
 99	}
100	t.writer.pendingKeyChange <- ciph
101
102	return nil
103}
104
105func (t *transport) printPacket(p []byte, write bool) {
106	if len(p) == 0 {
107		return
108	}
109	who := "server"
110	if t.isClient {
111		who = "client"
112	}
113	what := "read"
114	if write {
115		what = "write"
116	}
117
118	log.Println(what, who, p[0])
119}
120
121// Read and decrypt next packet.
122func (t *transport) readPacket() (p []byte, err error) {
123	for {
124		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
125		if err != nil {
126			break
127		}
128		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
129		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
130			break
131		}
132	}
133	if debugTransport {
134		t.printPacket(p, false)
135	}
136
137	return p, err
138}
139
140func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
141	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
142	s.seqNum++
143	if err == nil && len(packet) == 0 {
144		err = errors.New("ssh: zero length packet")
145	}
146
147	if len(packet) > 0 {
148		switch packet[0] {
149		case msgNewKeys:
150			select {
151			case cipher := <-s.pendingKeyChange:
152				s.packetCipher = cipher
153				if strictMode {
154					s.seqNum = 0
155				}
156			default:
157				return nil, errors.New("ssh: got bogus newkeys message")
158			}
159
160		case msgDisconnect:
161			// Transform a disconnect message into an
162			// error. Since this is lowest level at which
163			// we interpret message types, doing it here
164			// ensures that we don't have to handle it
165			// elsewhere.
166			var msg disconnectMsg
167			if err := Unmarshal(packet, &msg); err != nil {
168				return nil, err
169			}
170			return nil, &msg
171		}
172	}
173
174	// The packet may point to an internal buffer, so copy the
175	// packet out here.
176	fresh := make([]byte, len(packet))
177	copy(fresh, packet)
178
179	return fresh, err
180}
181
182func (t *transport) writePacket(packet []byte) error {
183	if debugTransport {
184		t.printPacket(packet, true)
185	}
186	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
187}
188
189func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
190	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
191
192	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
193	if err != nil {
194		return err
195	}
196	if err = w.Flush(); err != nil {
197		return err
198	}
199	s.seqNum++
200	if changeKeys {
201		select {
202		case cipher := <-s.pendingKeyChange:
203			s.packetCipher = cipher
204			if strictMode {
205				s.seqNum = 0
206			}
207		default:
208			panic("ssh: no key material for msgNewKeys")
209		}
210	}
211	return err
212}
213
214func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
215	t := &transport{
216		bufReader: bufio.NewReader(rwc),
217		bufWriter: bufio.NewWriter(rwc),
218		rand:      rand,
219		reader: connectionState{
220			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
221			pendingKeyChange: make(chan packetCipher, 1),
222		},
223		writer: connectionState{
224			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
225			pendingKeyChange: make(chan packetCipher, 1),
226		},
227		Closer: rwc,
228	}
229	t.isClient = isClient
230
231	if isClient {
232		t.reader.dir = serverKeys
233		t.writer.dir = clientKeys
234	} else {
235		t.reader.dir = clientKeys
236		t.writer.dir = serverKeys
237	}
238
239	return t
240}
241
// direction holds the single-byte tags that select which key is derived
// for one direction of traffic (RFC 4253, section 7.2).
type direction struct {
	ivTag     []byte // tag for the initial IV
	keyTag    []byte // tag for the encryption key
	macKeyTag []byte // tag for the integrity (MAC) key
}

// serverKeys derives server-to-client keys; clientKeys derives
// client-to-server keys.
var (
	serverKeys = direction{ivTag: []byte{'B'}, keyTag: []byte{'D'}, macKeyTag: []byte{'F'}}
	clientKeys = direction{ivTag: []byte{'A'}, keyTag: []byte{'C'}, macKeyTag: []byte{'E'}}
)
252
253// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
254// described in RFC 4253, section 6.4. direction should either be serverKeys
255// (to setup server->client keys) or clientKeys (for client->server keys).
256func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) {
257	cipherMode := cipherModes[algs.Cipher]
258	if cipherMode == nil {
259		return nil, fmt.Errorf("ssh: unsupported cipher %v", algs.Cipher)
260	}
261
262	iv := make([]byte, cipherMode.ivSize)
263	key := make([]byte, cipherMode.keySize)
264
265	generateKeyMaterial(iv, d.ivTag, kex)
266	generateKeyMaterial(key, d.keyTag, kex)
267
268	var macKey []byte
269	if !aeadCiphers[algs.Cipher] {
270		macMode := macModes[algs.MAC]
271		macKey = make([]byte, macMode.keySize)
272		generateKeyMaterial(macKey, d.macKeyTag, kex)
273	}
274
275	return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
276}
277
278// generateKeyMaterial fills out with key material generated from tag, K, H
279// and sessionId, as specified in RFC 4253, section 7.2.
280func generateKeyMaterial(out, tag []byte, r *kexResult) {
281	var digestsSoFar []byte
282
283	h := r.Hash.New()
284	for len(out) > 0 {
285		h.Reset()
286		h.Write(r.K)
287		h.Write(r.H)
288
289		if len(digestsSoFar) == 0 {
290			h.Write(tag)
291			h.Write(r.SessionID)
292		} else {
293			h.Write(digestsSoFar)
294		}
295
296		digest := h.Sum(nil)
297		n := copy(out, digest)
298		out = out[n:]
299		if len(out) > 0 {
300			digestsSoFar = append(digestsSoFar, digest...)
301		}
302	}
303}
304
// packageVersion is an SSH identification string announcing protocol
// version 2.0 and software version "Go".
// NOTE(review): presumably the default local version line passed to
// exchangeVersions; confirm against callers.
const packageVersion = "SSH-2.0-Go"
306
// exchangeVersions writes versionLine followed by "\r\n" to rw, then reads
// and returns the peer's version line (see readVersion).
//
// versionLine should be US-ASCII, start with "SSH-2.0-" and must not
// include a newline. It is rejected, before anything is written, if it
// contains any control character (byte < 32, which includes NUL, CR and
// LF). versionLine itself is never modified.
//
// Contrary to the RFC, the peer's line is accepted as long as it starts
// with "SSH-", not only "SSH-2.0-", which keeps the library usable with
// nonconforming servers.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
	for _, c := range versionLine {
		// The spec disallows non US-ASCII chars and specifically forbids
		// null chars. Only control characters are rejected here.
		if c < 32 {
			return nil, errors.New("ssh: junk character in version line")
		}
	}

	// Build the line in a fresh buffer. Appending to versionLine directly
	// would write "\r\n" into the caller's backing array whenever it has
	// spare capacity.
	line := make([]byte, 0, len(versionLine)+2)
	line = append(line, versionLine...)
	line = append(line, '\r', '\n')
	if _, err := rw.Write(line); err != nil {
		return nil, err
	}

	return readVersion(rw)
}

// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars.
const maxVersionStringBytes = 255

// readVersion reads the peer's version line as specified by RFC 4253,
// section 4.2.
//
// Lines that do not start with "SSH-" are skipped. A line may end with
// "\r\n" or a bare "\n", and one trailing '\r' is stripped from the
// result. At most maxVersionStringBytes bytes are consumed in total,
// counting skipped lines; exceeding that returns an overflow error. Bytes
// are read one at a time, so nothing past the version line is consumed
// from r.
func readVersion(r io.Reader) ([]byte, error) {
	versionString := make([]byte, 0, 64)
	var buf [1]byte

	for length := 0; length < maxVersionStringBytes; length++ {
		if _, err := io.ReadFull(r, buf[:]); err != nil {
			return nil, err
		}

		if c := buf[0]; c != '\n' {
			// Non-ASCII bytes are tolerated. Any comment after a space is
			// kept as well, since the whole line goes into the session hash.
			versionString = append(versionString, c)
			continue
		}

		if !bytes.HasPrefix(versionString, []byte("SSH-")) {
			// Not the version line: discard it and keep reading. The byte
			// budget is shared by all lines.
			versionString = versionString[:0]
			continue
		}

		// The RFC requires "\r\n", but several servers send only "\n".
		return bytes.TrimSuffix(versionString, []byte{'\r'}), nil
	}

	return nil, errors.New("ssh: overflow reading version string")
}