1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssh
6
7import (
8 "crypto"
9 "crypto/mlkem"
10 "crypto/sha256"
11 "errors"
12 "fmt"
13 "io"
14
15 "golang.org/x/crypto/curve25519"
16)
17
// mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with
// curve25519-sha256 key exchange method, as described by
// draft-kampanakis-curdle-ssh-pq-ke-05 section 2.3.3.
//
// The type carries no state of its own. Each call to Client or Server
// generates fresh ephemeral ML-KEM-768 and X25519 key material. The two
// resulting shared secrets are combined as SHA-256(ML-KEM secret || X25519
// secret). The exchange hash H is also computed with SHA-256.
type mlkem768WithCurve25519sha256 struct{}
22
23func (kex *mlkem768WithCurve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
24 var c25519kp curve25519KeyPair
25 if err := c25519kp.generate(rand); err != nil {
26 return nil, err
27 }
28
29 seed := make([]byte, mlkem.SeedSize)
30 if _, err := io.ReadFull(rand, seed); err != nil {
31 return nil, err
32 }
33
34 mlkemDk, err := mlkem.NewDecapsulationKey768(seed)
35 if err != nil {
36 return nil, err
37 }
38
39 hybridKey := append(mlkemDk.EncapsulationKey().Bytes(), c25519kp.pub[:]...)
40 if err := c.writePacket(Marshal(&kexECDHInitMsg{hybridKey})); err != nil {
41 return nil, err
42 }
43
44 packet, err := c.readPacket()
45 if err != nil {
46 return nil, err
47 }
48
49 var reply kexECDHReplyMsg
50 if err = Unmarshal(packet, &reply); err != nil {
51 return nil, err
52 }
53
54 if len(reply.EphemeralPubKey) != mlkem.CiphertextSize768+32 {
55 return nil, errors.New("ssh: peer's mlkem768x25519 public value has wrong length")
56 }
57
58 // Perform KEM decapsulate operation to obtain shared key from ML-KEM.
59 mlkem768Secret, err := mlkemDk.Decapsulate(reply.EphemeralPubKey[:mlkem.CiphertextSize768])
60 if err != nil {
61 return nil, err
62 }
63
64 // Complete Curve25519 ECDH to obtain its shared key.
65 c25519Secret, err := curve25519.X25519(c25519kp.priv[:], reply.EphemeralPubKey[mlkem.CiphertextSize768:])
66 if err != nil {
67 return nil, fmt.Errorf("ssh: peer's mlkem768x25519 public value is not valid: %w", err)
68 }
69 // Compute actual shared key.
70 h := sha256.New()
71 h.Write(mlkem768Secret)
72 h.Write(c25519Secret)
73 secret := h.Sum(nil)
74
75 h.Reset()
76 magics.write(h)
77 writeString(h, reply.HostKey)
78 writeString(h, hybridKey)
79 writeString(h, reply.EphemeralPubKey)
80
81 K := make([]byte, stringLength(len(secret)))
82 marshalString(K, secret)
83 h.Write(K)
84
85 return &kexResult{
86 H: h.Sum(nil),
87 K: K,
88 HostKey: reply.HostKey,
89 Signature: reply.Signature,
90 Hash: crypto.SHA256,
91 }, nil
92}
93
94func (kex *mlkem768WithCurve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (*kexResult, error) {
95 packet, err := c.readPacket()
96 if err != nil {
97 return nil, err
98 }
99
100 var kexInit kexECDHInitMsg
101 if err = Unmarshal(packet, &kexInit); err != nil {
102 return nil, err
103 }
104
105 if len(kexInit.ClientPubKey) != mlkem.EncapsulationKeySize768+32 {
106 return nil, errors.New("ssh: peer's ML-KEM768/curve25519 public value has wrong length")
107 }
108
109 encapsulationKey, err := mlkem.NewEncapsulationKey768(kexInit.ClientPubKey[:mlkem.EncapsulationKeySize768])
110 if err != nil {
111 return nil, fmt.Errorf("ssh: peer's ML-KEM768 encapsulation key is not valid: %w", err)
112 }
113 // Perform KEM encapsulate operation to obtain ciphertext and shared key.
114 mlkem768Secret, mlkem768Ciphertext := encapsulationKey.Encapsulate()
115
116 // Perform server side of Curve25519 ECDH to obtain server public value and
117 // shared key.
118 var c25519kp curve25519KeyPair
119 if err := c25519kp.generate(rand); err != nil {
120 return nil, err
121 }
122 c25519Secret, err := curve25519.X25519(c25519kp.priv[:], kexInit.ClientPubKey[mlkem.EncapsulationKeySize768:])
123 if err != nil {
124 return nil, fmt.Errorf("ssh: peer's ML-KEM768/curve25519 public value is not valid: %w", err)
125 }
126 hybridKey := append(mlkem768Ciphertext, c25519kp.pub[:]...)
127
128 // Compute actual shared key.
129 h := sha256.New()
130 h.Write(mlkem768Secret)
131 h.Write(c25519Secret)
132 secret := h.Sum(nil)
133
134 hostKeyBytes := priv.PublicKey().Marshal()
135
136 h.Reset()
137 magics.write(h)
138 writeString(h, hostKeyBytes)
139 writeString(h, kexInit.ClientPubKey)
140 writeString(h, hybridKey)
141
142 K := make([]byte, stringLength(len(secret)))
143 marshalString(K, secret)
144 h.Write(K)
145
146 H := h.Sum(nil)
147
148 sig, err := signAndMarshal(priv, rand, H, algo)
149 if err != nil {
150 return nil, err
151 }
152
153 reply := kexECDHReplyMsg{
154 EphemeralPubKey: hybridKey,
155 HostKey: hostKeyBytes,
156 Signature: sig,
157 }
158 if err := c.writePacket(Marshal(&reply)); err != nil {
159 return nil, err
160 }
161 return &kexResult{
162 H: H,
163 K: K,
164 HostKey: hostKeyBytes,
165 Signature: sig,
166 Hash: crypto.SHA256,
167 }, nil
168}