1// Copyright 2024 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package ssh
  6
  7import (
  8	"crypto"
  9	"crypto/mlkem"
 10	"crypto/sha256"
 11	"errors"
 12	"fmt"
 13	"io"
 14
 15	"golang.org/x/crypto/curve25519"
 16)
 17
 18// mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with
 19// curve25519-sha256 key exchange method, as described by
 20// draft-kampanakis-curdle-ssh-pq-ke-05 section 2.3.3.
 21type mlkem768WithCurve25519sha256 struct{}
 22
 23func (kex *mlkem768WithCurve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
 24	var c25519kp curve25519KeyPair
 25	if err := c25519kp.generate(rand); err != nil {
 26		return nil, err
 27	}
 28
 29	seed := make([]byte, mlkem.SeedSize)
 30	if _, err := io.ReadFull(rand, seed); err != nil {
 31		return nil, err
 32	}
 33
 34	mlkemDk, err := mlkem.NewDecapsulationKey768(seed)
 35	if err != nil {
 36		return nil, err
 37	}
 38
 39	hybridKey := append(mlkemDk.EncapsulationKey().Bytes(), c25519kp.pub[:]...)
 40	if err := c.writePacket(Marshal(&kexECDHInitMsg{hybridKey})); err != nil {
 41		return nil, err
 42	}
 43
 44	packet, err := c.readPacket()
 45	if err != nil {
 46		return nil, err
 47	}
 48
 49	var reply kexECDHReplyMsg
 50	if err = Unmarshal(packet, &reply); err != nil {
 51		return nil, err
 52	}
 53
 54	if len(reply.EphemeralPubKey) != mlkem.CiphertextSize768+32 {
 55		return nil, errors.New("ssh: peer's mlkem768x25519 public value has wrong length")
 56	}
 57
 58	// Perform KEM decapsulate operation to obtain shared key from ML-KEM.
 59	mlkem768Secret, err := mlkemDk.Decapsulate(reply.EphemeralPubKey[:mlkem.CiphertextSize768])
 60	if err != nil {
 61		return nil, err
 62	}
 63
 64	// Complete Curve25519 ECDH to obtain its shared key.
 65	c25519Secret, err := curve25519.X25519(c25519kp.priv[:], reply.EphemeralPubKey[mlkem.CiphertextSize768:])
 66	if err != nil {
 67		return nil, fmt.Errorf("ssh: peer's mlkem768x25519 public value is not valid: %w", err)
 68	}
 69	// Compute actual shared key.
 70	h := sha256.New()
 71	h.Write(mlkem768Secret)
 72	h.Write(c25519Secret)
 73	secret := h.Sum(nil)
 74
 75	h.Reset()
 76	magics.write(h)
 77	writeString(h, reply.HostKey)
 78	writeString(h, hybridKey)
 79	writeString(h, reply.EphemeralPubKey)
 80
 81	K := make([]byte, stringLength(len(secret)))
 82	marshalString(K, secret)
 83	h.Write(K)
 84
 85	return &kexResult{
 86		H:         h.Sum(nil),
 87		K:         K,
 88		HostKey:   reply.HostKey,
 89		Signature: reply.Signature,
 90		Hash:      crypto.SHA256,
 91	}, nil
 92}
 93
 94func (kex *mlkem768WithCurve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (*kexResult, error) {
 95	packet, err := c.readPacket()
 96	if err != nil {
 97		return nil, err
 98	}
 99
100	var kexInit kexECDHInitMsg
101	if err = Unmarshal(packet, &kexInit); err != nil {
102		return nil, err
103	}
104
105	if len(kexInit.ClientPubKey) != mlkem.EncapsulationKeySize768+32 {
106		return nil, errors.New("ssh: peer's ML-KEM768/curve25519 public value has wrong length")
107	}
108
109	encapsulationKey, err := mlkem.NewEncapsulationKey768(kexInit.ClientPubKey[:mlkem.EncapsulationKeySize768])
110	if err != nil {
111		return nil, fmt.Errorf("ssh: peer's ML-KEM768 encapsulation key is not valid: %w", err)
112	}
113	// Perform KEM encapsulate operation to obtain ciphertext and shared key.
114	mlkem768Secret, mlkem768Ciphertext := encapsulationKey.Encapsulate()
115
116	// Perform server side of Curve25519 ECDH to obtain server public value and
117	// shared key.
118	var c25519kp curve25519KeyPair
119	if err := c25519kp.generate(rand); err != nil {
120		return nil, err
121	}
122	c25519Secret, err := curve25519.X25519(c25519kp.priv[:], kexInit.ClientPubKey[mlkem.EncapsulationKeySize768:])
123	if err != nil {
124		return nil, fmt.Errorf("ssh: peer's ML-KEM768/curve25519 public value is not valid: %w", err)
125	}
126	hybridKey := append(mlkem768Ciphertext, c25519kp.pub[:]...)
127
128	// Compute actual shared key.
129	h := sha256.New()
130	h.Write(mlkem768Secret)
131	h.Write(c25519Secret)
132	secret := h.Sum(nil)
133
134	hostKeyBytes := priv.PublicKey().Marshal()
135
136	h.Reset()
137	magics.write(h)
138	writeString(h, hostKeyBytes)
139	writeString(h, kexInit.ClientPubKey)
140	writeString(h, hybridKey)
141
142	K := make([]byte, stringLength(len(secret)))
143	marshalString(K, secret)
144	h.Write(K)
145
146	H := h.Sum(nil)
147
148	sig, err := signAndMarshal(priv, rand, H, algo)
149	if err != nil {
150		return nil, err
151	}
152
153	reply := kexECDHReplyMsg{
154		EphemeralPubKey: hybridKey,
155		HostKey:         hostKeyBytes,
156		Signature:       sig,
157	}
158	if err := c.writePacket(Marshal(&reply)); err != nil {
159		return nil, err
160	}
161	return &kexResult{
162		H:         H,
163		K:         K,
164		HostKey:   hostKeyBytes,
165		Signature: sig,
166		Hash:      crypto.SHA256,
167	}, nil
168}