1// Copyright 2012 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package agent
  6
  7import (
  8	"crypto/dsa"
  9	"crypto/ecdsa"
 10	"crypto/ed25519"
 11	"crypto/elliptic"
 12	"crypto/rsa"
 13	"encoding/binary"
 14	"errors"
 15	"fmt"
 16	"io"
 17	"log"
 18	"math/big"
 19
 20	"golang.org/x/crypto/ssh"
 21)
 22
// server wraps an Agent and uses it to implement the agent side of
// the SSH-agent wire protocol.
type server struct {
	// agent is the backing Agent that decoded requests are dispatched to.
	agent Agent
}
 28
 29func (s *server) processRequestBytes(reqData []byte) []byte {
 30	rep, err := s.processRequest(reqData)
 31	if err != nil {
 32		if err != errLocked {
 33			// TODO(hanwen): provide better logging interface?
 34			log.Printf("agent %d: %v", reqData[0], err)
 35		}
 36		return []byte{agentFailure}
 37	}
 38
 39	if rep == nil {
 40		return []byte{agentSuccess}
 41	}
 42
 43	return ssh.Marshal(rep)
 44}
 45
 46func marshalKey(k *Key) []byte {
 47	var record struct {
 48		Blob    []byte
 49		Comment string
 50	}
 51	record.Blob = k.Marshal()
 52	record.Comment = k.Comment
 53
 54	return ssh.Marshal(&record)
 55}
 56
// agentV1IdentitiesAnswer is the message number 2; the same value
// appears as the sshtype tag of agentV1IdentityMsg.
// See [PROTOCOL.agent], section 2.5.1.
const agentV1IdentitiesAnswer = 2
 59
// agentV1IdentityMsg is the reply (message type 2) that processRequest
// returns for agentRequestV1Identities; it is always sent with Numkeys
// set to 0.
type agentV1IdentityMsg struct {
	Numkeys uint32 `sshtype:"2"`
}
 63
// agentRemoveIdentityMsg is the request (message type 18) decoded for
// the agentRemoveIdentity opcode. KeyBlob is the serialized key that is
// handed to Agent.Remove.
type agentRemoveIdentityMsg struct {
	KeyBlob []byte `sshtype:"18"`
}
 67
// agentLockMsg is the request (message type 22) decoded for the
// agentLock opcode. Passphrase is passed unchanged to Agent.Lock.
type agentLockMsg struct {
	Passphrase []byte `sshtype:"22"`
}
 71
// agentUnlockMsg is the request (message type 23) decoded for the
// agentUnlock opcode. Passphrase is passed unchanged to Agent.Unlock.
type agentUnlockMsg struct {
	Passphrase []byte `sshtype:"23"`
}
 75
 76func (s *server) processRequest(data []byte) (interface{}, error) {
 77	switch data[0] {
 78	case agentRequestV1Identities:
 79		return &agentV1IdentityMsg{0}, nil
 80
 81	case agentRemoveAllV1Identities:
 82		return nil, nil
 83
 84	case agentRemoveIdentity:
 85		var req agentRemoveIdentityMsg
 86		if err := ssh.Unmarshal(data, &req); err != nil {
 87			return nil, err
 88		}
 89
 90		var wk wireKey
 91		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
 92			return nil, err
 93		}
 94
 95		return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
 96
 97	case agentRemoveAllIdentities:
 98		return nil, s.agent.RemoveAll()
 99
100	case agentLock:
101		var req agentLockMsg
102		if err := ssh.Unmarshal(data, &req); err != nil {
103			return nil, err
104		}
105
106		return nil, s.agent.Lock(req.Passphrase)
107
108	case agentUnlock:
109		var req agentUnlockMsg
110		if err := ssh.Unmarshal(data, &req); err != nil {
111			return nil, err
112		}
113		return nil, s.agent.Unlock(req.Passphrase)
114
115	case agentSignRequest:
116		var req signRequestAgentMsg
117		if err := ssh.Unmarshal(data, &req); err != nil {
118			return nil, err
119		}
120
121		var wk wireKey
122		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
123			return nil, err
124		}
125
126		k := &Key{
127			Format: wk.Format,
128			Blob:   req.KeyBlob,
129		}
130
131		var sig *ssh.Signature
132		var err error
133		if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
134			sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
135		} else {
136			sig, err = s.agent.Sign(k, req.Data)
137		}
138
139		if err != nil {
140			return nil, err
141		}
142		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
143
144	case agentRequestIdentities:
145		keys, err := s.agent.List()
146		if err != nil {
147			return nil, err
148		}
149
150		rep := identitiesAnswerAgentMsg{
151			NumKeys: uint32(len(keys)),
152		}
153		for _, k := range keys {
154			rep.Keys = append(rep.Keys, marshalKey(k)...)
155		}
156		return rep, nil
157
158	case agentAddIDConstrained, agentAddIdentity:
159		return nil, s.insertIdentity(data)
160
161	case agentExtension:
162		// Return a stub object where the whole contents of the response gets marshaled.
163		var responseStub struct {
164			Rest []byte `ssh:"rest"`
165		}
166
167		if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
168			// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
169			// requires that we return a standard SSH_AGENT_FAILURE message.
170			responseStub.Rest = []byte{agentFailure}
171		} else {
172			var req extensionAgentMsg
173			if err := ssh.Unmarshal(data, &req); err != nil {
174				return nil, err
175			}
176			res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
177			if err != nil {
178				// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
179				// message as required by [PROTOCOL.agent] section 4.7.
180				if err == ErrExtensionUnsupported {
181					responseStub.Rest = []byte{agentFailure}
182				} else {
183					// As the result of any other error processing an extension request,
184					// [PROTOCOL.agent] section 4.7 requires that we return a
185					// SSH_AGENT_EXTENSION_FAILURE code.
186					responseStub.Rest = []byte{agentExtensionFailure}
187				}
188			} else {
189				if len(res) == 0 {
190					return nil, nil
191				}
192				responseStub.Rest = res
193			}
194		}
195
196		return responseStub, nil
197	}
198
199	return nil, fmt.Errorf("unknown opcode %d", data[0])
200}
201
202func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
203	for len(constraints) != 0 {
204		switch constraints[0] {
205		case agentConstrainLifetime:
206			if len(constraints) < 5 {
207				return 0, false, nil, io.ErrUnexpectedEOF
208			}
209			lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
210			constraints = constraints[5:]
211		case agentConstrainConfirm:
212			confirmBeforeUse = true
213			constraints = constraints[1:]
214		case agentConstrainExtension, agentConstrainExtensionV00:
215			var msg constrainExtensionAgentMsg
216			if err = ssh.Unmarshal(constraints, &msg); err != nil {
217				return 0, false, nil, err
218			}
219			extensions = append(extensions, ConstraintExtension{
220				ExtensionName:    msg.ExtensionName,
221				ExtensionDetails: msg.ExtensionDetails,
222			})
223			constraints = msg.Rest
224		default:
225			return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
226		}
227	}
228	return
229}
230
231func setConstraints(key *AddedKey, constraintBytes []byte) error {
232	lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
233	if err != nil {
234		return err
235	}
236
237	key.LifetimeSecs = lifetimeSecs
238	key.ConfirmBeforeUse = confirmBeforeUse
239	key.ConstraintExtensions = constraintExtensions
240	return nil
241}
242
243func parseRSAKey(req []byte) (*AddedKey, error) {
244	var k rsaKeyMsg
245	if err := ssh.Unmarshal(req, &k); err != nil {
246		return nil, err
247	}
248	if k.E.BitLen() > 30 {
249		return nil, errors.New("agent: RSA public exponent too large")
250	}
251	priv := &rsa.PrivateKey{
252		PublicKey: rsa.PublicKey{
253			E: int(k.E.Int64()),
254			N: k.N,
255		},
256		D:      k.D,
257		Primes: []*big.Int{k.P, k.Q},
258	}
259	priv.Precompute()
260
261	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
262	if err := setConstraints(addedKey, k.Constraints); err != nil {
263		return nil, err
264	}
265	return addedKey, nil
266}
267
268func parseEd25519Key(req []byte) (*AddedKey, error) {
269	var k ed25519KeyMsg
270	if err := ssh.Unmarshal(req, &k); err != nil {
271		return nil, err
272	}
273	priv := ed25519.PrivateKey(k.Priv)
274
275	addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
276	if err := setConstraints(addedKey, k.Constraints); err != nil {
277		return nil, err
278	}
279	return addedKey, nil
280}
281
282func parseDSAKey(req []byte) (*AddedKey, error) {
283	var k dsaKeyMsg
284	if err := ssh.Unmarshal(req, &k); err != nil {
285		return nil, err
286	}
287	priv := &dsa.PrivateKey{
288		PublicKey: dsa.PublicKey{
289			Parameters: dsa.Parameters{
290				P: k.P,
291				Q: k.Q,
292				G: k.G,
293			},
294			Y: k.Y,
295		},
296		X: k.X,
297	}
298
299	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
300	if err := setConstraints(addedKey, k.Constraints); err != nil {
301		return nil, err
302	}
303	return addedKey, nil
304}
305
// unmarshalECDSA assembles an ECDSA private key from its wire components.
//
// curveName selects the curve ("nistp256", "nistp384" or "nistp521"),
// keyBytes is the encoded public point as accepted by elliptic.Unmarshal,
// and privScalar is the private scalar D.
//
// An error is returned if the curve is unknown, if privScalar is nil or
// outside the range [1, N-1] for the curve order N, or if keyBytes does
// not decode to a point on the curve.
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
	var curve elliptic.Curve
	switch curveName {
	case "nistp256":
		curve = elliptic.P256()
	case "nistp384":
		curve = elliptic.P384()
	case "nistp521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
	}

	// The scalar is supplied by the client; refuse values that cannot be
	// a valid private key for this curve.
	if privScalar == nil || privScalar.Sign() <= 0 || privScalar.Cmp(curve.Params().N) >= 0 {
		return nil, errors.New("agent: ECDSA private scalar out of range")
	}

	x, y := elliptic.Unmarshal(curve, keyBytes)
	if x == nil || y == nil {
		return nil, errors.New("agent: point not on curve")
	}

	return &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{Curve: curve, X: x, Y: y},
		D:         privScalar,
	}, nil
}
329
330func parseEd25519Cert(req []byte) (*AddedKey, error) {
331	var k ed25519CertMsg
332	if err := ssh.Unmarshal(req, &k); err != nil {
333		return nil, err
334	}
335	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
336	if err != nil {
337		return nil, err
338	}
339	priv := ed25519.PrivateKey(k.Priv)
340	cert, ok := pubKey.(*ssh.Certificate)
341	if !ok {
342		return nil, errors.New("agent: bad ED25519 certificate")
343	}
344
345	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
346	if err := setConstraints(addedKey, k.Constraints); err != nil {
347		return nil, err
348	}
349	return addedKey, nil
350}
351
352func parseECDSAKey(req []byte) (*AddedKey, error) {
353	var k ecdsaKeyMsg
354	if err := ssh.Unmarshal(req, &k); err != nil {
355		return nil, err
356	}
357
358	priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
359	if err != nil {
360		return nil, err
361	}
362
363	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
364	if err := setConstraints(addedKey, k.Constraints); err != nil {
365		return nil, err
366	}
367	return addedKey, nil
368}
369
370func parseRSACert(req []byte) (*AddedKey, error) {
371	var k rsaCertMsg
372	if err := ssh.Unmarshal(req, &k); err != nil {
373		return nil, err
374	}
375
376	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
377	if err != nil {
378		return nil, err
379	}
380
381	cert, ok := pubKey.(*ssh.Certificate)
382	if !ok {
383		return nil, errors.New("agent: bad RSA certificate")
384	}
385
386	// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
387	var rsaPub struct {
388		Name string
389		E    *big.Int
390		N    *big.Int
391	}
392	if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
393		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
394	}
395
396	if rsaPub.E.BitLen() > 30 {
397		return nil, errors.New("agent: RSA public exponent too large")
398	}
399
400	priv := rsa.PrivateKey{
401		PublicKey: rsa.PublicKey{
402			E: int(rsaPub.E.Int64()),
403			N: rsaPub.N,
404		},
405		D:      k.D,
406		Primes: []*big.Int{k.Q, k.P},
407	}
408	priv.Precompute()
409
410	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
411	if err := setConstraints(addedKey, k.Constraints); err != nil {
412		return nil, err
413	}
414	return addedKey, nil
415}
416
417func parseDSACert(req []byte) (*AddedKey, error) {
418	var k dsaCertMsg
419	if err := ssh.Unmarshal(req, &k); err != nil {
420		return nil, err
421	}
422	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
423	if err != nil {
424		return nil, err
425	}
426	cert, ok := pubKey.(*ssh.Certificate)
427	if !ok {
428		return nil, errors.New("agent: bad DSA certificate")
429	}
430
431	// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
432	var w struct {
433		Name       string
434		P, Q, G, Y *big.Int
435	}
436	if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
437		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
438	}
439
440	priv := &dsa.PrivateKey{
441		PublicKey: dsa.PublicKey{
442			Parameters: dsa.Parameters{
443				P: w.P,
444				Q: w.Q,
445				G: w.G,
446			},
447			Y: w.Y,
448		},
449		X: k.X,
450	}
451
452	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
453	if err := setConstraints(addedKey, k.Constraints); err != nil {
454		return nil, err
455	}
456	return addedKey, nil
457}
458
459func parseECDSACert(req []byte) (*AddedKey, error) {
460	var k ecdsaCertMsg
461	if err := ssh.Unmarshal(req, &k); err != nil {
462		return nil, err
463	}
464
465	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
466	if err != nil {
467		return nil, err
468	}
469	cert, ok := pubKey.(*ssh.Certificate)
470	if !ok {
471		return nil, errors.New("agent: bad ECDSA certificate")
472	}
473
474	// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
475	var ecdsaPub struct {
476		Name string
477		ID   string
478		Key  []byte
479	}
480	if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
481		return nil, err
482	}
483
484	priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
485	if err != nil {
486		return nil, err
487	}
488
489	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
490	if err := setConstraints(addedKey, k.Constraints); err != nil {
491		return nil, err
492	}
493	return addedKey, nil
494}
495
496func (s *server) insertIdentity(req []byte) error {
497	var record struct {
498		Type string `sshtype:"17|25"`
499		Rest []byte `ssh:"rest"`
500	}
501
502	if err := ssh.Unmarshal(req, &record); err != nil {
503		return err
504	}
505
506	var addedKey *AddedKey
507	var err error
508
509	switch record.Type {
510	case ssh.KeyAlgoRSA:
511		addedKey, err = parseRSAKey(req)
512	case ssh.InsecureKeyAlgoDSA:
513		addedKey, err = parseDSAKey(req)
514	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
515		addedKey, err = parseECDSAKey(req)
516	case ssh.KeyAlgoED25519:
517		addedKey, err = parseEd25519Key(req)
518	case ssh.CertAlgoRSAv01:
519		addedKey, err = parseRSACert(req)
520	case ssh.InsecureCertAlgoDSAv01:
521		addedKey, err = parseDSACert(req)
522	case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
523		addedKey, err = parseECDSACert(req)
524	case ssh.CertAlgoED25519v01:
525		addedKey, err = parseEd25519Cert(req)
526	default:
527		return fmt.Errorf("agent: not implemented: %q", record.Type)
528	}
529
530	if err != nil {
531		return err
532	}
533	return s.agent.Add(*addedKey)
534}
535
536// ServeAgent serves the agent protocol on the given connection. It
537// returns when an I/O error occurs.
538func ServeAgent(agent Agent, c io.ReadWriter) error {
539	s := &server{agent}
540
541	var length [4]byte
542	for {
543		if _, err := io.ReadFull(c, length[:]); err != nil {
544			return err
545		}
546		l := binary.BigEndian.Uint32(length[:])
547		if l == 0 {
548			return fmt.Errorf("agent: request size is 0")
549		}
550		if l > maxAgentResponseBytes {
551			// We also cap requests.
552			return fmt.Errorf("agent: request too large: %d", l)
553		}
554
555		req := make([]byte, l)
556		if _, err := io.ReadFull(c, req); err != nil {
557			return err
558		}
559
560		repData := s.processRequestBytes(req)
561		if len(repData) > maxAgentResponseBytes {
562			return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
563		}
564
565		binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
566		if _, err := c.Write(length[:]); err != nil {
567			return err
568		}
569		if _, err := c.Write(repData); err != nil {
570			return err
571		}
572	}
573}