1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package agent
6
7import (
8 "crypto/dsa"
9 "crypto/ecdsa"
10 "crypto/ed25519"
11 "crypto/elliptic"
12 "crypto/rsa"
13 "encoding/binary"
14 "errors"
15 "fmt"
16 "io"
17 "log"
18 "math/big"
19
20 "golang.org/x/crypto/ssh"
21)
22
// server wraps an Agent and uses it to implement the agent side of
// the SSH-agent wire protocol.
type server struct {
	// agent is the implementation that decoded requests are dispatched to.
	agent Agent
}
28
29func (s *server) processRequestBytes(reqData []byte) []byte {
30 rep, err := s.processRequest(reqData)
31 if err != nil {
32 if err != errLocked {
33 // TODO(hanwen): provide better logging interface?
34 log.Printf("agent %d: %v", reqData[0], err)
35 }
36 return []byte{agentFailure}
37 }
38
39 if rep == nil {
40 return []byte{agentSuccess}
41 }
42
43 return ssh.Marshal(rep)
44}
45
46func marshalKey(k *Key) []byte {
47 var record struct {
48 Blob []byte
49 Comment string
50 }
51 record.Blob = k.Marshal()
52 record.Comment = k.Comment
53
54 return ssh.Marshal(&record)
55}
56
// agentV1IdentitiesAnswer is the message number of the identities answer
// for protocol version 1; it has the same value as the sshtype tag on
// agentV1IdentityMsg.
//
// See [PROTOCOL.agent], section 2.5.1.
const agentV1IdentitiesAnswer = 2
59
// agentV1IdentityMsg is the reply that processRequest returns for an
// agentRequestV1Identities request. It carries only a key count.
type agentV1IdentityMsg struct {
	// Numkeys is the number of keys reported; processRequest always sends 0.
	Numkeys uint32 `sshtype:"2"`
}
63
// agentRemoveIdentityMsg is the request that processRequest decodes for the
// agentRemoveIdentity opcode.
type agentRemoveIdentityMsg struct {
	// KeyBlob is the encoded key to remove. It is passed on to the Agent as
	// Key.Blob, and its format name is read from it via wireKey.
	KeyBlob []byte `sshtype:"18"`
}
67
// agentLockMsg is the request that processRequest decodes for the agentLock
// opcode.
type agentLockMsg struct {
	// Passphrase is handed unchanged to Agent.Lock.
	Passphrase []byte `sshtype:"22"`
}
71
// agentUnlockMsg is the request that processRequest decodes for the
// agentUnlock opcode.
type agentUnlockMsg struct {
	// Passphrase is handed unchanged to Agent.Unlock.
	Passphrase []byte `sshtype:"23"`
}
75
76func (s *server) processRequest(data []byte) (interface{}, error) {
77 switch data[0] {
78 case agentRequestV1Identities:
79 return &agentV1IdentityMsg{0}, nil
80
81 case agentRemoveAllV1Identities:
82 return nil, nil
83
84 case agentRemoveIdentity:
85 var req agentRemoveIdentityMsg
86 if err := ssh.Unmarshal(data, &req); err != nil {
87 return nil, err
88 }
89
90 var wk wireKey
91 if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
92 return nil, err
93 }
94
95 return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
96
97 case agentRemoveAllIdentities:
98 return nil, s.agent.RemoveAll()
99
100 case agentLock:
101 var req agentLockMsg
102 if err := ssh.Unmarshal(data, &req); err != nil {
103 return nil, err
104 }
105
106 return nil, s.agent.Lock(req.Passphrase)
107
108 case agentUnlock:
109 var req agentUnlockMsg
110 if err := ssh.Unmarshal(data, &req); err != nil {
111 return nil, err
112 }
113 return nil, s.agent.Unlock(req.Passphrase)
114
115 case agentSignRequest:
116 var req signRequestAgentMsg
117 if err := ssh.Unmarshal(data, &req); err != nil {
118 return nil, err
119 }
120
121 var wk wireKey
122 if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
123 return nil, err
124 }
125
126 k := &Key{
127 Format: wk.Format,
128 Blob: req.KeyBlob,
129 }
130
131 var sig *ssh.Signature
132 var err error
133 if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
134 sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
135 } else {
136 sig, err = s.agent.Sign(k, req.Data)
137 }
138
139 if err != nil {
140 return nil, err
141 }
142 return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
143
144 case agentRequestIdentities:
145 keys, err := s.agent.List()
146 if err != nil {
147 return nil, err
148 }
149
150 rep := identitiesAnswerAgentMsg{
151 NumKeys: uint32(len(keys)),
152 }
153 for _, k := range keys {
154 rep.Keys = append(rep.Keys, marshalKey(k)...)
155 }
156 return rep, nil
157
158 case agentAddIDConstrained, agentAddIdentity:
159 return nil, s.insertIdentity(data)
160
161 case agentExtension:
162 // Return a stub object where the whole contents of the response gets marshaled.
163 var responseStub struct {
164 Rest []byte `ssh:"rest"`
165 }
166
167 if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
168 // If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
169 // requires that we return a standard SSH_AGENT_FAILURE message.
170 responseStub.Rest = []byte{agentFailure}
171 } else {
172 var req extensionAgentMsg
173 if err := ssh.Unmarshal(data, &req); err != nil {
174 return nil, err
175 }
176 res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
177 if err != nil {
178 // If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
179 // message as required by [PROTOCOL.agent] section 4.7.
180 if err == ErrExtensionUnsupported {
181 responseStub.Rest = []byte{agentFailure}
182 } else {
183 // As the result of any other error processing an extension request,
184 // [PROTOCOL.agent] section 4.7 requires that we return a
185 // SSH_AGENT_EXTENSION_FAILURE code.
186 responseStub.Rest = []byte{agentExtensionFailure}
187 }
188 } else {
189 if len(res) == 0 {
190 return nil, nil
191 }
192 responseStub.Rest = res
193 }
194 }
195
196 return responseStub, nil
197 }
198
199 return nil, fmt.Errorf("unknown opcode %d", data[0])
200}
201
202func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
203 for len(constraints) != 0 {
204 switch constraints[0] {
205 case agentConstrainLifetime:
206 if len(constraints) < 5 {
207 return 0, false, nil, io.ErrUnexpectedEOF
208 }
209 lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
210 constraints = constraints[5:]
211 case agentConstrainConfirm:
212 confirmBeforeUse = true
213 constraints = constraints[1:]
214 case agentConstrainExtension, agentConstrainExtensionV00:
215 var msg constrainExtensionAgentMsg
216 if err = ssh.Unmarshal(constraints, &msg); err != nil {
217 return 0, false, nil, err
218 }
219 extensions = append(extensions, ConstraintExtension{
220 ExtensionName: msg.ExtensionName,
221 ExtensionDetails: msg.ExtensionDetails,
222 })
223 constraints = msg.Rest
224 default:
225 return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
226 }
227 }
228 return
229}
230
231func setConstraints(key *AddedKey, constraintBytes []byte) error {
232 lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
233 if err != nil {
234 return err
235 }
236
237 key.LifetimeSecs = lifetimeSecs
238 key.ConfirmBeforeUse = confirmBeforeUse
239 key.ConstraintExtensions = constraintExtensions
240 return nil
241}
242
243func parseRSAKey(req []byte) (*AddedKey, error) {
244 var k rsaKeyMsg
245 if err := ssh.Unmarshal(req, &k); err != nil {
246 return nil, err
247 }
248 if k.E.BitLen() > 30 {
249 return nil, errors.New("agent: RSA public exponent too large")
250 }
251 priv := &rsa.PrivateKey{
252 PublicKey: rsa.PublicKey{
253 E: int(k.E.Int64()),
254 N: k.N,
255 },
256 D: k.D,
257 Primes: []*big.Int{k.P, k.Q},
258 }
259 priv.Precompute()
260
261 addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
262 if err := setConstraints(addedKey, k.Constraints); err != nil {
263 return nil, err
264 }
265 return addedKey, nil
266}
267
268func parseEd25519Key(req []byte) (*AddedKey, error) {
269 var k ed25519KeyMsg
270 if err := ssh.Unmarshal(req, &k); err != nil {
271 return nil, err
272 }
273 priv := ed25519.PrivateKey(k.Priv)
274
275 addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
276 if err := setConstraints(addedKey, k.Constraints); err != nil {
277 return nil, err
278 }
279 return addedKey, nil
280}
281
282func parseDSAKey(req []byte) (*AddedKey, error) {
283 var k dsaKeyMsg
284 if err := ssh.Unmarshal(req, &k); err != nil {
285 return nil, err
286 }
287 priv := &dsa.PrivateKey{
288 PublicKey: dsa.PublicKey{
289 Parameters: dsa.Parameters{
290 P: k.P,
291 Q: k.Q,
292 G: k.G,
293 },
294 Y: k.Y,
295 },
296 X: k.X,
297 }
298
299 addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
300 if err := setConstraints(addedKey, k.Constraints); err != nil {
301 return nil, err
302 }
303 return addedKey, nil
304}
305
// unmarshalECDSA builds an ECDSA private key from a curve name ("nistp256",
// "nistp384" or "nistp521"), an encoded public point and the private scalar.
//
// It returns an error for any other curve name, and when keyBytes cannot be
// decoded into a point for the selected curve.
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
	var curve elliptic.Curve
	switch curveName {
	case "nistp256":
		curve = elliptic.P256()
	case "nistp384":
		curve = elliptic.P384()
	case "nistp521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
	}

	// elliptic.Unmarshal reports failure by returning nil coordinates.
	x, y := elliptic.Unmarshal(curve, keyBytes)
	if x == nil || y == nil {
		return nil, errors.New("agent: point not on curve")
	}

	return &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{Curve: curve, X: x, Y: y},
		D:         privScalar,
	}, nil
}
329
330func parseEd25519Cert(req []byte) (*AddedKey, error) {
331 var k ed25519CertMsg
332 if err := ssh.Unmarshal(req, &k); err != nil {
333 return nil, err
334 }
335 pubKey, err := ssh.ParsePublicKey(k.CertBytes)
336 if err != nil {
337 return nil, err
338 }
339 priv := ed25519.PrivateKey(k.Priv)
340 cert, ok := pubKey.(*ssh.Certificate)
341 if !ok {
342 return nil, errors.New("agent: bad ED25519 certificate")
343 }
344
345 addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
346 if err := setConstraints(addedKey, k.Constraints); err != nil {
347 return nil, err
348 }
349 return addedKey, nil
350}
351
352func parseECDSAKey(req []byte) (*AddedKey, error) {
353 var k ecdsaKeyMsg
354 if err := ssh.Unmarshal(req, &k); err != nil {
355 return nil, err
356 }
357
358 priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
359 if err != nil {
360 return nil, err
361 }
362
363 addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
364 if err := setConstraints(addedKey, k.Constraints); err != nil {
365 return nil, err
366 }
367 return addedKey, nil
368}
369
370func parseRSACert(req []byte) (*AddedKey, error) {
371 var k rsaCertMsg
372 if err := ssh.Unmarshal(req, &k); err != nil {
373 return nil, err
374 }
375
376 pubKey, err := ssh.ParsePublicKey(k.CertBytes)
377 if err != nil {
378 return nil, err
379 }
380
381 cert, ok := pubKey.(*ssh.Certificate)
382 if !ok {
383 return nil, errors.New("agent: bad RSA certificate")
384 }
385
386 // An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
387 var rsaPub struct {
388 Name string
389 E *big.Int
390 N *big.Int
391 }
392 if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
393 return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
394 }
395
396 if rsaPub.E.BitLen() > 30 {
397 return nil, errors.New("agent: RSA public exponent too large")
398 }
399
400 priv := rsa.PrivateKey{
401 PublicKey: rsa.PublicKey{
402 E: int(rsaPub.E.Int64()),
403 N: rsaPub.N,
404 },
405 D: k.D,
406 Primes: []*big.Int{k.Q, k.P},
407 }
408 priv.Precompute()
409
410 addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
411 if err := setConstraints(addedKey, k.Constraints); err != nil {
412 return nil, err
413 }
414 return addedKey, nil
415}
416
417func parseDSACert(req []byte) (*AddedKey, error) {
418 var k dsaCertMsg
419 if err := ssh.Unmarshal(req, &k); err != nil {
420 return nil, err
421 }
422 pubKey, err := ssh.ParsePublicKey(k.CertBytes)
423 if err != nil {
424 return nil, err
425 }
426 cert, ok := pubKey.(*ssh.Certificate)
427 if !ok {
428 return nil, errors.New("agent: bad DSA certificate")
429 }
430
431 // A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
432 var w struct {
433 Name string
434 P, Q, G, Y *big.Int
435 }
436 if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
437 return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
438 }
439
440 priv := &dsa.PrivateKey{
441 PublicKey: dsa.PublicKey{
442 Parameters: dsa.Parameters{
443 P: w.P,
444 Q: w.Q,
445 G: w.G,
446 },
447 Y: w.Y,
448 },
449 X: k.X,
450 }
451
452 addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
453 if err := setConstraints(addedKey, k.Constraints); err != nil {
454 return nil, err
455 }
456 return addedKey, nil
457}
458
459func parseECDSACert(req []byte) (*AddedKey, error) {
460 var k ecdsaCertMsg
461 if err := ssh.Unmarshal(req, &k); err != nil {
462 return nil, err
463 }
464
465 pubKey, err := ssh.ParsePublicKey(k.CertBytes)
466 if err != nil {
467 return nil, err
468 }
469 cert, ok := pubKey.(*ssh.Certificate)
470 if !ok {
471 return nil, errors.New("agent: bad ECDSA certificate")
472 }
473
474 // An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
475 var ecdsaPub struct {
476 Name string
477 ID string
478 Key []byte
479 }
480 if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
481 return nil, err
482 }
483
484 priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
485 if err != nil {
486 return nil, err
487 }
488
489 addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
490 if err := setConstraints(addedKey, k.Constraints); err != nil {
491 return nil, err
492 }
493 return addedKey, nil
494}
495
496func (s *server) insertIdentity(req []byte) error {
497 var record struct {
498 Type string `sshtype:"17|25"`
499 Rest []byte `ssh:"rest"`
500 }
501
502 if err := ssh.Unmarshal(req, &record); err != nil {
503 return err
504 }
505
506 var addedKey *AddedKey
507 var err error
508
509 switch record.Type {
510 case ssh.KeyAlgoRSA:
511 addedKey, err = parseRSAKey(req)
512 case ssh.InsecureKeyAlgoDSA:
513 addedKey, err = parseDSAKey(req)
514 case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
515 addedKey, err = parseECDSAKey(req)
516 case ssh.KeyAlgoED25519:
517 addedKey, err = parseEd25519Key(req)
518 case ssh.CertAlgoRSAv01:
519 addedKey, err = parseRSACert(req)
520 case ssh.InsecureCertAlgoDSAv01:
521 addedKey, err = parseDSACert(req)
522 case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
523 addedKey, err = parseECDSACert(req)
524 case ssh.CertAlgoED25519v01:
525 addedKey, err = parseEd25519Cert(req)
526 default:
527 return fmt.Errorf("agent: not implemented: %q", record.Type)
528 }
529
530 if err != nil {
531 return err
532 }
533 return s.agent.Add(*addedKey)
534}
535
536// ServeAgent serves the agent protocol on the given connection. It
537// returns when an I/O error occurs.
538func ServeAgent(agent Agent, c io.ReadWriter) error {
539 s := &server{agent}
540
541 var length [4]byte
542 for {
543 if _, err := io.ReadFull(c, length[:]); err != nil {
544 return err
545 }
546 l := binary.BigEndian.Uint32(length[:])
547 if l == 0 {
548 return fmt.Errorf("agent: request size is 0")
549 }
550 if l > maxAgentResponseBytes {
551 // We also cap requests.
552 return fmt.Errorf("agent: request too large: %d", l)
553 }
554
555 req := make([]byte, l)
556 if _, err := io.ReadFull(c, req); err != nil {
557 return err
558 }
559
560 repData := s.processRequestBytes(req)
561 if len(repData) > maxAgentResponseBytes {
562 return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
563 }
564
565 binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
566 if _, err := c.Write(length[:]); err != nil {
567 return err
568 }
569 if _, err := c.Write(repData); err != nil {
570 return err
571 }
572 }
573}