1#include "api.h"
   2#include "./parser.h"
   3#include <stdint.h>
   4
   5#ifdef TREE_SITTER_FEATURE_WASM
   6
   7#include <wasmtime.h>
   8#include <wasm.h>
   9#include <string.h>
  10#include "./alloc.h"
  11#include "./array.h"
  12#include "./atomic.h"
  13#include "./language.h"
  14#include "./lexer.h"
  15#include "./wasm_store.h"
  16#include "./wasm/wasm-stdlib.h"
  17
  18#define array_len(a) (sizeof(a) / sizeof(a[0]))
  19
// The following symbols from the C and C++ standard libraries are available
// for external scanners to use.
//
// The initializer is spliced in from `stdlib-symbols.txt`, which therefore must
// contain a comma-separated list of string literals. `ts_wasm_store_new` looks
// up each name among the wasm stdlib's exported functions and records the
// indices in `TSWasmStore.stdlib_fn_indices`, which is parallel to this array.
// Store creation fails if any listed symbol is not exported.
const char *STDLIB_SYMBOLS[] = {
  #include "./stdlib-symbols.txt"
};
  25
// The contents of the `dylink.0` custom section of a wasm module,
// as specified by the current WebAssembly dynamic linking ABI proposal.
//
// Only the memory-info subsection is captured. `wasm_dylink_info__parse` reads
// the four fields in exactly this order from that subsection.
typedef struct {
  uint32_t memory_size;   // presumably bytes of linear memory the module needs — confirm against the dylink ABI
  uint32_t memory_align;  // NOTE(review): the dylink ABI encodes alignments as a power of two (log2) — confirm
  uint32_t table_size;    // presumably number of function-table slots the module needs
  uint32_t table_align;   // NOTE(review): likely log2, like memory_align — confirm
} WasmDylinkInfo;
  34
// WasmLanguageId - A pointer used to identify a language. This language id is
// reference-counted, so that its ownership can be shared between the language
// itself and the instances of the language that are held in wasm stores.
//
// Created by `language_id_new` with a count of 1. Shared with
// `language_id_clone` and released with `language_id_delete`, which frees the
// struct when the count reaches zero.
typedef struct {
  // Modified only through the atomic_inc/atomic_dec helpers.
  volatile uint32_t ref_count;
  // Initialized to false by `language_id_new`. Presumably set once the owning
  // language is deleted, so stores can drop stale instances; the writer is not
  // in view.
  volatile uint32_t is_language_deleted;
} WasmLanguageId;
  42
// LanguageWasmModule - Additional data associated with a wasm-backed
// `TSLanguage`. This data is read-only and does not reference a particular
// wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to
// this is stored on the language itself.
typedef struct {
  // NOTE(review): presumably counts the TSLanguage copies sharing this module;
  // its users are not in view.
  volatile uint32_t ref_count;
  // Identity shared with every per-store instance of this language.
  WasmLanguageId *language_id;
  // Compiled module, instantiated separately in each store.
  wasmtime_module_t *module;
  const char *name;
  // Presumably the backing storage for the language's symbol/field name
  // pointers (see `copy_strings`) — confirm against the loader.
  char *symbol_name_buffer;
  char *field_name_buffer;
  // Memory/table requirements parsed from the module's `dylink.0` section.
  WasmDylinkInfo dylink_info;
} LanguageWasmModule;
  56
// LanguageWasmInstance - Additional data associated with an instantiation of
// a `TSLanguage` in a particular wasm store. The wasm store holds one of
// these structs for each language that it has instantiated.
typedef struct {
  // Reference held by this instance. `ts_wasm_store_delete` releases it with
  // `language_id_delete`.
  WasmLanguageId *language_id;
  wasmtime_instance_t instance;
  // Wasm address of the language's external-scanner state table.
  int32_t external_states_address;
  // NOTE(review): the *_fn_index fields presumably hold function-table indices
  // of the language's lexing/scanner entry points — confirm with the loader,
  // which is not in view.
  int32_t lex_main_fn_index;
  int32_t lex_keyword_fn_index;
  int32_t scanner_create_fn_index;
  int32_t scanner_destroy_fn_index;
  int32_t scanner_serialize_fn_index;
  int32_t scanner_deserialize_fn_index;
  int32_t scanner_scan_fn_index;
} LanguageWasmInstance;
  72
// BuiltinFunctionIndices - Store-level indices of the host functions (and the
// stdlib's `reset_heap` export) that modules in a store may import.
//
// `ts_wasm_store_new` fills every field except `reset_heap` with the `index` of
// a host function created via `wasmtime_func_new_unchecked`. `reset_heap` is
// taken from the stdlib instance's exports. `get_builtin_extern` turns an index
// back into a function extern when satisfying an import.
typedef struct {
  uint32_t reset_heap;
  uint32_t proc_exit;
  uint32_t abort;
  uint32_t assert_fail;
  uint32_t notify_memory_growth;
  uint32_t debug_message;
  uint32_t at_exit;
  uint32_t args_get;
  uint32_t args_sizes_get;
} BuiltinFunctionIndices;
  84
// TSWasmStore - A struct that allows a given `Parser` to use wasm-backed
// languages. This struct is mutable, and can only be used by one parser at a
// time.
struct TSWasmStore {
  // Released by `ts_wasm_store_delete`.
  wasm_engine_t *engine;
  wasmtime_store_t *store;
  // Table shared by all modules (provided as `__indirect_function_table`).
  wasmtime_table_t function_table;
  // Linear memory shared by all modules (provided as `memory`).
  wasmtime_memory_t memory;
  // Host lexer that the lexer callbacks forward to.
  TSLexer *current_lexer;
  LanguageWasmInstance *current_instance;
  // One entry per language instantiated in this store.
  Array(LanguageWasmInstance) language_instances;
  // Value handed to the next instantiated module as `__memory_base`.
  uint32_t current_memory_offset;
  // Value handed to the next instantiated module as `__table_base`.
  uint32_t current_function_table_offset;
  // Stdlib export indices, parallel to `STDLIB_SYMBOLS`.
  uint32_t *stdlib_fn_indices;
  BuiltinFunctionIndices builtin_fn_indices;
  // Mutable i32 global provided as `__stack_pointer`.
  wasmtime_global_t stack_pointer_global;
  // Type reused for the immutable i32 base-offset globals. Released by
  // `ts_wasm_store_delete`.
  wasm_globaltype_t *const_i32_type;
  // NOTE(review): not referenced in this part of the file; presumably records
  // a failed wasm call.
  bool has_error;
  // Wasm address of the guest-visible `LexerInWasmMemory`.
  uint32_t lexer_address;
  // Wasm address of the serialization buffer, placed right after the lexer.
  uint32_t serialization_buffer_address;
};
 106
// Growable byte buffer. `copy_strings` appends NUL-terminated strings to it.
typedef Array(char) StringData;
 108
// LanguageInWasmMemory - The memory layout of a `TSLanguage` when compiled to
// wasm32. This is used to copy static language data out of the wasm memory.
//
// Field order, types and padding must match the wasm32 layout exactly; do not
// reorder. Pointer-valued members of `TSLanguage` are 32 bits wide in wasm32
// and appear here as `int32_t` addresses into linear memory.
typedef struct {
  uint32_t version;
  uint32_t symbol_count;
  uint32_t alias_count;
  uint32_t token_count;
  uint32_t external_token_count;
  uint32_t state_count;
  uint32_t large_state_count;
  uint32_t production_id_count;
  uint32_t field_count;
  uint16_t max_alias_sequence_length;
  int32_t parse_table;
  int32_t small_parse_table;
  int32_t small_parse_table_map;
  int32_t parse_actions;
  int32_t symbol_names;
  int32_t field_names;
  int32_t field_map_slices;
  int32_t field_map_entries;
  int32_t symbol_metadata;
  int32_t public_symbol_map;
  int32_t alias_map;
  int32_t alias_sequences;
  int32_t lex_modes;
  int32_t lex_fn;
  int32_t keyword_lex_fn;
  TSSymbol keyword_capture_token;
  struct {
    int32_t states;
    int32_t symbol_map;
    int32_t create;
    int32_t destroy;
    int32_t scan;
    int32_t serialize;
    int32_t deserialize;
  } external_scanner;
  int32_t primary_state_ids;
} LanguageInWasmMemory;
 149
// LexerInWasmMemory - The memory layout of a `TSLexer` when compiled to wasm32.
// This is used to copy mutable lexing state in and out of the wasm memory.
//
// The function-pointer members hold function-table indices of the host lexer
// callbacks; `ts_wasm_store_new` writes them. Field order must match wasm32.
typedef struct {
  // Must stay the first field: `callback__lexer_advance` writes the new
  // lookahead directly at `lexer_address`.
  int32_t lookahead;
  TSSymbol result_symbol;
  int32_t advance;
  int32_t mark_end;
  int32_t get_column;
  int32_t is_at_included_range_start;
  int32_t eof;
} LexerInWasmMemory;
 161
// NOTE(review): not referenced in this part of the file; presumably a
// process-wide counter for assigning language ids — confirm its users.
static volatile uint32_t NEXT_LANGUAGE_ID;
 163
// Linear memory layout:
// [ <-- stack | stdlib statics | lexer | serialization_buffer | language statics --> | heap --> ]
//
// Upper bound on the store's linear memory: 128 MiB, expressed as a page count
// and used as the maximum in the memory limits.
#define MAX_MEMORY_SIZE (128 * 1024 * 1024 / MEMORY_PAGE_SIZE)
 167
 168/************************
 169 * WasmDylinkMemoryInfo
 170 ***********************/
 171
 172static uint8_t read_u8(const uint8_t **p, const uint8_t *end) {
 173  return *(*p)++;
 174}
 175
 176static inline uint64_t read_uleb128(const uint8_t **p, const uint8_t *end) {
 177  uint64_t value = 0;
 178  unsigned shift = 0;
 179  do {
 180    if (*p == end)  return UINT64_MAX;
 181    value += (uint64_t)(**p & 0x7f) << shift;
 182    shift += 7;
 183  } while (*((*p)++) >= 128);
 184  return value;
 185}
 186
 187static bool wasm_dylink_info__parse(
 188  const uint8_t *bytes,
 189  size_t length,
 190  WasmDylinkInfo *info
 191) {
 192  const uint8_t WASM_MAGIC_NUMBER[4] = {0, 'a', 's', 'm'};
 193  const uint8_t WASM_VERSION[4] = {1, 0, 0, 0};
 194  const uint8_t WASM_CUSTOM_SECTION = 0x0;
 195  const uint8_t WASM_DYLINK_MEM_INFO = 0x1;
 196
 197  const uint8_t *p = bytes;
 198  const uint8_t *end = bytes + length;
 199
 200  if (length < 8) return false;
 201  if (memcmp(p, WASM_MAGIC_NUMBER, 4) != 0) return false;
 202  p += 4;
 203  if (memcmp(p, WASM_VERSION, 4) != 0) return false;
 204  p += 4;
 205
 206  while (p < end) {
 207    uint8_t section_id = read_u8(&p, end);
 208    uint32_t section_length = read_uleb128(&p, end);
 209    const uint8_t *section_end = p + section_length;
 210    if (section_end > end) return false;
 211
 212    if (section_id == WASM_CUSTOM_SECTION) {
 213      uint32_t name_length = read_uleb128(&p, section_end);
 214      const uint8_t *name_end = p + name_length;
 215      if (name_end > section_end) return false;
 216
 217      if (name_length == 8 && memcmp(p, "dylink.0", 8) == 0) {
 218        p = name_end;
 219        while (p < section_end) {
 220          uint8_t subsection_type = read_u8(&p, section_end);
 221          uint32_t subsection_size = read_uleb128(&p, section_end);
 222          const uint8_t *subsection_end = p + subsection_size;
 223          if (subsection_end > section_end) return false;
 224          if (subsection_type == WASM_DYLINK_MEM_INFO) {
 225            info->memory_size = read_uleb128(&p, subsection_end);
 226            info->memory_align = read_uleb128(&p, subsection_end);
 227            info->table_size = read_uleb128(&p, subsection_end);
 228            info->table_align = read_uleb128(&p, subsection_end);
 229            return true;
 230          }
 231          p = subsection_end;
 232        }
 233      }
 234    }
 235    p = section_end;
 236  }
 237  return false;
 238}
 239
 240/*******************************************
 241 * Native callbacks exposed to wasm modules
 242 *******************************************/
 243
 244 static wasm_trap_t *callback__abort(
 245  void *env,
 246  wasmtime_caller_t* caller,
 247  wasmtime_val_raw_t *args_and_results,
 248  size_t args_and_results_len
 249) {
 250  return wasmtime_trap_new("wasm module called abort", 24);
 251}
 252
 253static wasm_trap_t *callback__debug_message(
 254  void *env,
 255  wasmtime_caller_t* caller,
 256  wasmtime_val_raw_t *args_and_results,
 257  size_t args_and_results_len
 258) {
 259  wasmtime_context_t *context = wasmtime_caller_context(caller);
 260  TSWasmStore *store = env;
 261  assert(args_and_results_len == 2);
 262  uint32_t string_address = args_and_results[0].i32;
 263  uint32_t value = args_and_results[1].i32;
 264  uint8_t *memory = wasmtime_memory_data(context, &store->memory);
 265  printf("DEBUG: %s %u\n", &memory[string_address], value);
 266  return NULL;
 267}
 268
 269static wasm_trap_t *callback__noop(
 270  void *env,
 271  wasmtime_caller_t* caller,
 272  wasmtime_val_raw_t *args_and_results,
 273  size_t args_and_results_len
 274) {
 275  return NULL;
 276}
 277
 278static wasm_trap_t *callback__lexer_advance(
 279  void *env,
 280  wasmtime_caller_t* caller,
 281  wasmtime_val_raw_t *args_and_results,
 282  size_t args_and_results_len
 283) {
 284  wasmtime_context_t *context = wasmtime_caller_context(caller);
 285  assert(args_and_results_len == 2);
 286
 287  TSWasmStore *store = env;
 288  TSLexer *lexer = store->current_lexer;
 289  bool skip = args_and_results[1].i32;
 290  lexer->advance(lexer, skip);
 291
 292  uint8_t *memory = wasmtime_memory_data(context, &store->memory);
 293  memcpy(&memory[store->lexer_address], &lexer->lookahead, sizeof(lexer->lookahead));
 294  return NULL;
 295}
 296
 297static wasm_trap_t *callback__lexer_mark_end(
 298  void *env,
 299  wasmtime_caller_t* caller,
 300  wasmtime_val_raw_t *args_and_results,
 301  size_t args_and_results_len
 302) {
 303  TSWasmStore *store = env;
 304  TSLexer *lexer = store->current_lexer;
 305  lexer->mark_end(lexer);
 306  return NULL;
 307}
 308
 309static wasm_trap_t *callback__lexer_get_column(
 310  void *env,
 311  wasmtime_caller_t* caller,
 312  wasmtime_val_raw_t *args_and_results,
 313  size_t args_and_results_len
 314) {
 315  TSWasmStore *store = env;
 316  TSLexer *lexer = store->current_lexer;
 317  uint32_t result = lexer->get_column(lexer);
 318  args_and_results[0].i32 = result;
 319  return NULL;
 320}
 321
 322static wasm_trap_t *callback__lexer_is_at_included_range_start(
 323  void *env,
 324  wasmtime_caller_t* caller,
 325  wasmtime_val_raw_t *args_and_results,
 326  size_t args_and_results_len
 327) {
 328  TSWasmStore *store = env;
 329  TSLexer *lexer = store->current_lexer;
 330  bool result = lexer->is_at_included_range_start(lexer);
 331  args_and_results[0].i32 = result;
 332  return NULL;
 333}
 334
 335static wasm_trap_t *callback__lexer_eof(
 336  void *env,
 337  wasmtime_caller_t* caller,
 338  wasmtime_val_raw_t *args_and_results,
 339  size_t args_and_results_len
 340) {
 341  TSWasmStore *store = env;
 342  TSLexer *lexer = store->current_lexer;
 343  bool result = lexer->eof(lexer);
 344  args_and_results[0].i32 = result;
 345  return NULL;
 346}
 347
// FunctionDefinition - Describes one host function that `ts_wasm_store_new`
// creates in the store.
typedef struct {
  // Receives the created function's index. For the lexer callbacks this is
  // later overwritten with the function's slot in the function table.
  uint32_t *storage_location;
  wasmtime_func_unchecked_callback_t callback;
  // Wasm signature of the function. It is deleted right after the function
  // is created.
  wasm_functype_t *type;
} FunctionDefinition;
 353
 354static void *copy(const void *data, size_t size) {
 355  void *result = ts_malloc(size);
 356  memcpy(result, data, size);
 357  return result;
 358}
 359
 360static void *copy_unsized_static_array(
 361  const uint8_t *data,
 362  int32_t start_address,
 363  const int32_t all_addresses[],
 364  size_t address_count
 365) {
 366  int32_t end_address = 0;
 367  for (unsigned i = 0; i < address_count; i++) {
 368    if (all_addresses[i] > start_address) {
 369      if (!end_address || all_addresses[i] < end_address) {
 370        end_address = all_addresses[i];
 371      }
 372    }
 373  }
 374
 375  if (!end_address) return NULL;
 376  size_t size = end_address - start_address;
 377  void *result = ts_malloc(size);
 378  memcpy(result, &data[start_address], size);
 379  return result;
 380}
 381
// Copy `count` C strings out of wasm linear memory.
//
// `data` is the base of linear memory and `array_address` the wasm address of
// an array of `count` 32-bit string addresses. Each non-zero address is read as
// a NUL-terminated string and appended, terminator included, to `string_data`.
// A zero address yields a NULL entry.
//
// Returns a ts_malloc'd array of `count` pointers into `string_data->contents`.
// Appending to `string_data` afterwards can move its buffer and invalidate
// these pointers.
//
// NOTE(review): `strlen` is not bounded by the size of linear memory, so an
// unterminated guest string near the end of memory would be over-read.
static void *copy_strings(
  const uint8_t *data,
  int32_t array_address,
  size_t count,
  StringData *string_data
) {
  const char **result = ts_malloc(count * sizeof(char *));
  for (unsigned i = 0; i < count; i++) {
    int32_t address;
    // memcpy makes no alignment or aliasing assumptions about the source bytes.
    memcpy(&address, &data[array_address + i * sizeof(address)], sizeof(address));
    if (address == 0) {
      // Sentinel for "no string". It becomes NULL in the second pass.
      result[i] = (const char *)-1;
    } else {
      const uint8_t *string = &data[address];
      uint32_t len = strlen((const char *)string);
      // Store the string's byte offset rather than a pointer: `array_extend`
      // can reallocate `string_data->contents` during this loop.
      result[i] = (const char *)(uintptr_t)string_data->size;
      array_extend(string_data, len + 1, string);
    }
  }
  // The buffer no longer grows, so convert offsets into real pointers.
  for (unsigned i = 0; i < count; i++) {
    if (result[i] == (const char *)-1) {
      result[i] = NULL;
    } else {
      result[i] = string_data->contents + (uintptr_t)result[i];
    }
  }
  return result;
}
 410
 411static bool name_eq(const wasm_name_t *name, const char *string) {
 412  return strncmp(string, name->data, name->size) == 0;
 413}
 414
 415static inline wasm_functype_t* wasm_functype_new_4_0(
 416  wasm_valtype_t* p1,
 417  wasm_valtype_t* p2,
 418  wasm_valtype_t* p3,
 419  wasm_valtype_t* p4
 420) {
 421  wasm_valtype_t* ps[4] = {p1, p2, p3, p4};
 422  wasm_valtype_vec_t params, results;
 423  wasm_valtype_vec_new(&params, 4, ps);
 424  wasm_valtype_vec_new_empty(&results);
 425  return wasm_functype_new(&params, &results);
 426}
 427
 428#define format(output, ...) \
 429  do { \
 430    size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \
 431    *output = ts_malloc(message_length + 1); \
 432    snprintf(*output, message_length + 1, __VA_ARGS__); \
 433  } while (0)
 434
 435WasmLanguageId *language_id_new() {
 436  WasmLanguageId *self = ts_malloc(sizeof(WasmLanguageId));
 437  self->is_language_deleted = false;
 438  self->ref_count = 1;
 439  return self;
 440}
 441
 442WasmLanguageId *language_id_clone(WasmLanguageId *self) {
 443  atomic_inc(&self->ref_count);
 444  return self;
 445}
 446
 447void language_id_delete(WasmLanguageId *self) {
 448  if (atomic_dec(&self->ref_count) == 0) {
 449    ts_free(self);
 450  }
 451}
 452
 453static wasmtime_extern_t get_builtin_extern(
 454  wasmtime_table_t *table,
 455  unsigned index
 456) {
 457  return (wasmtime_extern_t) {
 458    .kind = WASMTIME_EXTERN_FUNC,
 459    .of.func = (wasmtime_func_t) {
 460      .store_id = table->store_id,
 461      .index = index
 462    }
 463  };
 464}
 465
 466static bool ts_wasm_store__provide_builtin_import(
 467  TSWasmStore *self,
 468  const wasm_name_t *import_name,
 469  wasmtime_extern_t *import
 470) {
 471  wasmtime_error_t *error = NULL;
 472  wasmtime_context_t *context = wasmtime_store_context(self->store);
 473
 474  // Dynamic linking parameters
 475  if (name_eq(import_name, "__memory_base")) {
 476    wasmtime_val_t value = WASM_I32_VAL(self->current_memory_offset);
 477    wasmtime_global_t global;
 478    error = wasmtime_global_new(context, self->const_i32_type, &value, &global);
 479    assert(!error);
 480    *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
 481  } else if (name_eq(import_name, "__table_base")) {
 482    wasmtime_val_t value = WASM_I32_VAL(self->current_function_table_offset);
 483    wasmtime_global_t global;
 484    error = wasmtime_global_new(context, self->const_i32_type, &value, &global);
 485    assert(!error);
 486    *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
 487  } else if (name_eq(import_name, "__stack_pointer")) {
 488    *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = self->stack_pointer_global};
 489  } else if (name_eq(import_name, "__indirect_function_table")) {
 490    *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table};
 491  } else if (name_eq(import_name, "memory")) {
 492    *import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_MEMORY, .of.memory = self->memory};
 493  }
 494
 495  // Builtin functions
 496  else if (name_eq(import_name, "__assert_fail")) {
 497    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.assert_fail);
 498  } else if (name_eq(import_name, "__cxa_atexit")) {
 499    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.at_exit);
 500  } else if (name_eq(import_name, "args_get")) {
 501    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_get);
 502  } else if (name_eq(import_name, "args_sizes_get")) {
 503    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_sizes_get);
 504  } else if (name_eq(import_name, "abort")) {
 505    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.abort);
 506  } else if (name_eq(import_name, "proc_exit")) {
 507    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.proc_exit);
 508  } else if (name_eq(import_name, "emscripten_notify_memory_growth")) {
 509    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.notify_memory_growth);
 510  } else if (name_eq(import_name, "tree_sitter_debug_message")) {
 511    *import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.debug_message);
 512  } else {
 513    return false;
 514  }
 515
 516  return true;
 517}
 518
 519static bool ts_wasm_store__call_module_initializer(
 520  TSWasmStore *self,
 521  const wasm_name_t *export_name,
 522  wasmtime_extern_t *export,
 523  wasm_trap_t **trap
 524) {
 525  if (
 526    name_eq(export_name, "_initialize") ||
 527    name_eq(export_name, "__wasm_apply_data_relocs") ||
 528    name_eq(export_name, "__wasm_call_ctors")
 529  ) {
 530    wasmtime_context_t *context = wasmtime_store_context(self->store);
 531    wasmtime_func_t initialization_func = export->of.func;
 532    wasmtime_error_t *error = wasmtime_func_call(context, &initialization_func, NULL, 0, NULL, 0, trap);
 533    assert(!error);
 534    return true;
 535  } else {
 536    return false;
 537  }
 538}
 539
 540TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
 541  TSWasmStore *self = ts_calloc(1, sizeof(TSWasmStore));
 542  wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL);
 543  wasmtime_context_t *context = wasmtime_store_context(store);
 544  wasmtime_error_t *error = NULL;
 545  wasm_trap_t *trap = NULL;
 546  wasm_message_t message = WASM_EMPTY_VEC;
 547  wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC;
 548  wasmtime_extern_t *imports = NULL;
 549  wasmtime_module_t *stdlib_module = NULL;
 550  wasm_memorytype_t *memory_type = NULL;
 551  wasm_tabletype_t *table_type = NULL;
 552
 553  // Define functions called by scanners via function pointers on the lexer.
 554  LexerInWasmMemory lexer = {
 555    .lookahead = 0,
 556    .result_symbol = 0,
 557  };
 558  FunctionDefinition lexer_definitions[] = {
 559    {
 560      (uint32_t *)&lexer.advance,
 561      callback__lexer_advance,
 562      wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
 563    },
 564    {
 565      (uint32_t *)&lexer.mark_end,
 566      callback__lexer_mark_end,
 567      wasm_functype_new_1_0(wasm_valtype_new_i32())
 568    },
 569    {
 570      (uint32_t *)&lexer.get_column,
 571      callback__lexer_get_column,
 572      wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
 573    },
 574    {
 575      (uint32_t *)&lexer.is_at_included_range_start,
 576      callback__lexer_is_at_included_range_start,
 577      wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
 578    },
 579    {
 580      (uint32_t *)&lexer.eof,
 581      callback__lexer_eof,
 582      wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
 583    },
 584  };
 585
 586  // Define builtin functions that can be imported by scanners.
 587  BuiltinFunctionIndices builtin_fn_indices;
 588  FunctionDefinition builtin_definitions[] = {
 589    {
 590      &builtin_fn_indices.proc_exit,
 591      callback__abort,
 592      wasm_functype_new_1_0(wasm_valtype_new_i32())
 593    },
 594    {
 595      &builtin_fn_indices.abort,
 596      callback__abort,
 597      wasm_functype_new_0_0()
 598    },
 599    {
 600      &builtin_fn_indices.assert_fail,
 601      callback__abort,
 602      wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
 603    },
 604    {
 605      &builtin_fn_indices.notify_memory_growth,
 606      callback__noop,
 607      wasm_functype_new_1_0(wasm_valtype_new_i32())
 608    },
 609    {
 610      &builtin_fn_indices.debug_message,
 611      callback__debug_message,
 612      wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
 613    },
 614    {
 615      &builtin_fn_indices.at_exit,
 616      callback__noop,
 617      wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
 618    },
 619    {
 620      &builtin_fn_indices.args_get,
 621      callback__noop,
 622      wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
 623    },
 624    {
 625      &builtin_fn_indices.args_sizes_get,
 626      callback__noop,
 627      wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
 628    },
 629  };
 630
 631  // Create all of the wasm functions.
 632  unsigned builtin_definitions_len = array_len(builtin_definitions);
 633  unsigned lexer_definitions_len = array_len(lexer_definitions);
 634  for (unsigned i = 0; i < builtin_definitions_len; i++) {
 635    FunctionDefinition *definition = &builtin_definitions[i];
 636    wasmtime_func_t func;
 637    wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
 638    *definition->storage_location = func.index;
 639    wasm_functype_delete(definition->type);
 640  }
 641  for (unsigned i = 0; i < lexer_definitions_len; i++) {
 642    FunctionDefinition *definition = &lexer_definitions[i];
 643    wasmtime_func_t func;
 644    wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
 645    *definition->storage_location = func.index;
 646    wasm_functype_delete(definition->type);
 647  }
 648
 649  // Compile the stdlib module.
 650  error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module);
 651  if (error) {
 652    wasmtime_error_message(error, &message);
 653    wasm_error->kind = TSWasmErrorKindCompile;
 654    format(
 655      &wasm_error->message,
 656      "failed to compile wasm stdlib: %.*s",
 657      (int)message.size, message.data
 658    );
 659    goto error;
 660  }
 661
 662  // Retrieve the stdlib module's imports.
 663  wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
 664  wasmtime_module_imports(stdlib_module, &import_types);
 665
 666  // Find the initial number of memory pages needed by the stdlib.
 667  const wasm_memorytype_t *stdlib_memory_type;
 668  for (unsigned i = 0; i < import_types.size; i++) {
 669    wasm_importtype_t *import_type = import_types.data[i];
 670    const wasm_name_t *import_name = wasm_importtype_name(import_type);
 671    if (name_eq(import_name, "memory")) {
 672      const wasm_externtype_t *type = wasm_importtype_type(import_type);
 673      stdlib_memory_type = wasm_externtype_as_memorytype_const(type);
 674    }
 675  }
 676  if (!stdlib_memory_type) {
 677    wasm_error->kind = TSWasmErrorKindCompile;
 678    format(
 679      &wasm_error->message,
 680      "wasm stdlib is missing the 'memory' import"
 681    );
 682    goto error;
 683  }
 684
 685  // Initialize store's memory
 686  uint64_t initial_memory_pages = wasmtime_memorytype_minimum(stdlib_memory_type);
 687  wasm_limits_t memory_limits = {.min = initial_memory_pages, .max = MAX_MEMORY_SIZE};
 688  memory_type = wasm_memorytype_new(&memory_limits);
 689  wasmtime_memory_t memory;
 690  error = wasmtime_memory_new(context, memory_type, &memory);
 691  if (error) {
 692    wasmtime_error_message(error, &message);
 693    wasm_error->kind = TSWasmErrorKindAllocate;
 694    format(
 695      &wasm_error->message,
 696      "failed to allocate wasm memory: %.*s",
 697      (int)message.size, message.data
 698    );
 699    goto error;
 700  }
 701  wasm_memorytype_delete(memory_type);
 702  memory_type = NULL;
 703
 704  // Initialize store's function table
 705  wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default};
 706  table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits);
 707  wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF};
 708  wasmtime_table_t function_table;
 709  error = wasmtime_table_new(context, table_type, &initializer, &function_table);
 710  if (error) {
 711    wasmtime_error_message(error, &message);
 712    wasm_error->kind = TSWasmErrorKindAllocate;
 713    format(
 714      &wasm_error->message,
 715      "failed to allocate wasm table: %.*s",
 716      (int)message.size, message.data
 717    );
 718    goto error;
 719  }
 720  wasm_tabletype_delete(table_type);
 721  table_type = NULL;
 722
 723  unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS);
 724
 725  // Define globals for the stack and heap start addresses.
 726  wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST);
 727  wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR);
 728
 729  wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0);
 730  wasmtime_global_t stack_pointer_global;
 731  error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global);
 732  assert(!error);
 733
 734  *self = (TSWasmStore) {
 735    .engine = engine,
 736    .store = store,
 737    .memory = memory,
 738    .function_table = function_table,
 739    .language_instances = array_new(),
 740    .stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)),
 741    .builtin_fn_indices = builtin_fn_indices,
 742    .stack_pointer_global = stack_pointer_global,
 743    .current_memory_offset = 0,
 744    .current_function_table_offset = 0,
 745    .const_i32_type = const_i32_type,
 746  };
 747
 748  // Set up the imports for the stdlib module.
 749  imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t));
 750  for (unsigned i = 0; i < import_types.size; i++) {
 751    wasm_importtype_t *type = import_types.data[i];
 752    const wasm_name_t *import_name = wasm_importtype_name(type);
 753    if (!ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) {
 754      wasm_error->kind = TSWasmErrorKindInstantiate;
 755      format(
 756        &wasm_error->message,
 757        "unexpected import in wasm stdlib: %.*s\n",
 758        (int)import_name->size, import_name->data
 759      );
 760      goto error;
 761    }
 762  }
 763
 764  // Instantiate the stdlib module.
 765  wasmtime_instance_t instance;
 766  error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap);
 767  ts_free(imports);
 768  imports = NULL;
 769  if (error) {
 770    wasmtime_error_message(error, &message);
 771    wasm_error->kind = TSWasmErrorKindInstantiate;
 772    format(
 773      &wasm_error->message,
 774      "failed to instantiate wasm stdlib module: %.*s",
 775      (int)message.size, message.data
 776    );
 777    goto error;
 778  }
 779  if (trap) {
 780    wasm_trap_message(trap, &message);
 781    wasm_error->kind = TSWasmErrorKindInstantiate;
 782    format(
 783      &wasm_error->message,
 784      "trapped when instantiating wasm stdlib module: %.*s",
 785      (int)message.size, message.data
 786    );
 787    goto error;
 788  }
 789  wasm_importtype_vec_delete(&import_types);
 790
 791  // Process the stdlib module's exports.
 792  for (unsigned i = 0; i < stdlib_symbols_len; i++) {
 793    self->stdlib_fn_indices[i] = UINT32_MAX;
 794  }
 795  wasmtime_module_exports(stdlib_module, &export_types);
 796  for (unsigned i = 0; i < export_types.size; i++) {
 797    wasm_exporttype_t *export_type = export_types.data[i];
 798    const wasm_name_t *name = wasm_exporttype_name(export_type);
 799
 800    char *export_name;
 801    size_t name_len;
 802    wasmtime_extern_t export = {.kind = WASM_EXTERN_GLOBAL};
 803    bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export);
 804    assert(exists);
 805
 806    if (export.kind == WASMTIME_EXTERN_GLOBAL) {
 807      if (name_eq(name, "__stack_pointer")) {
 808        self->stack_pointer_global = export.of.global;
 809      }
 810    }
 811
 812    if (export.kind == WASMTIME_EXTERN_FUNC) {
 813      if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) {
 814        if (trap) {
 815          wasm_trap_message(trap, &message);
 816          wasm_error->kind = TSWasmErrorKindInstantiate;
 817          format(
 818            &wasm_error->message,
 819            "trap when calling stdlib relocation function: %.*s\n",
 820            (int)message.size, message.data
 821          );
 822          goto error;
 823        }
 824        continue;
 825      }
 826
 827      if (name_eq(name, "reset_heap")) {
 828        self->builtin_fn_indices.reset_heap = export.of.func.index;
 829        continue;
 830      }
 831
 832      for (unsigned j = 0; j < stdlib_symbols_len; j++) {
 833        if (name_eq(name, STDLIB_SYMBOLS[j])) {
 834          self->stdlib_fn_indices[j] = export.of.func.index;
 835          break;
 836        }
 837      }
 838    }
 839  }
 840
 841  if (self->builtin_fn_indices.reset_heap == UINT32_MAX) {
 842    wasm_error->kind = TSWasmErrorKindInstantiate;
 843    format(
 844      &wasm_error->message,
 845      "missing malloc reset function in wasm stdlib"
 846    );
 847    goto error;
 848  }
 849
 850  for (unsigned i = 0; i < stdlib_symbols_len; i++) {
 851    if (self->stdlib_fn_indices[i] == UINT32_MAX) {
 852      wasm_error->kind = TSWasmErrorKindInstantiate;
 853      format(
 854        &wasm_error->message,
 855        "missing exported symbol in wasm stdlib: %s",
 856        STDLIB_SYMBOLS[i]
 857      );
 858      goto error;
 859    }
 860  }
 861
 862  wasm_exporttype_vec_delete(&export_types);
 863  wasmtime_module_delete(stdlib_module);
 864
 865  // Add all of the lexer callback functions to the function table. Store their function table
 866  // indices on the in-memory lexer.
 867  uint32_t table_index;
 868  error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &table_index);
 869  if (error) {
 870    wasmtime_error_message(error, &message);
 871    wasm_error->kind = TSWasmErrorKindAllocate;
 872    format(
 873      &wasm_error->message,
 874      "failed to grow wasm table to initial size: %.*s",
 875      (int)message.size, message.data
 876    );
 877    goto error;
 878  }
 879  for (unsigned i = 0; i < lexer_definitions_len; i++) {
 880    FunctionDefinition *definition = &lexer_definitions[i];
 881    wasmtime_func_t func = {function_table.store_id, *definition->storage_location};
 882    wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func};
 883    error = wasmtime_table_set(context, &function_table, table_index, &func_val);
 884    assert(!error);
 885    *(int32_t *)(definition->storage_location) = table_index;
 886    table_index++;
 887  }
 888
 889  self->current_function_table_offset = table_index;
 890  self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE;
 891  self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory);
 892  self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
 893
 894  // Grow the memory enough to hold the builtin lexer and serialization buffer.
 895  uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1;
 896  uint64_t prev_memory_size;
 897  wasmtime_memory_grow(context, &memory, new_pages_needed, &prev_memory_size);
 898
 899  uint8_t *memory_data = wasmtime_memory_data(context, &memory);
 900  memcpy(&memory_data[self->lexer_address], &lexer, sizeof(lexer));
 901  return self;
 902
 903error:
 904  ts_free(self);
 905  if (stdlib_module) wasmtime_module_delete(stdlib_module);
 906  if (store) wasmtime_store_delete(store);
 907  if (import_types.size) wasm_importtype_vec_delete(&import_types);
 908  if (memory_type) wasm_memorytype_delete(memory_type);
 909  if (table_type) wasm_tabletype_delete(table_type);
 910  if (trap) wasm_trap_delete(trap);
 911  if (error) wasmtime_error_delete(error);
 912  if (message.size) wasm_byte_vec_delete(&message);
 913  if (export_types.size) wasm_exporttype_vec_delete(&export_types);
 914  if (imports) ts_free(imports);
 915  return NULL;
 916}
 917
 918void ts_wasm_store_delete(TSWasmStore *self) {
 919  if (!self) return;
 920  ts_free(self->stdlib_fn_indices);
 921  wasm_globaltype_delete(self->const_i32_type);
 922  wasmtime_store_delete(self->store);
 923  wasm_engine_delete(self->engine);
 924  for (unsigned i = 0; i < self->language_instances.size; i++) {
 925    LanguageWasmInstance *instance = &self->language_instances.contents[i];
 926    language_id_delete(instance->language_id);
 927  }
 928  array_delete(&self->language_instances);
 929  ts_free(self);
 930}
 931
 932size_t ts_wasm_store_language_count(const TSWasmStore *self) {
 933  size_t result = 0;
 934  for (unsigned i = 0; i < self->language_instances.size; i++) {
 935    const WasmLanguageId *id = self->language_instances.contents[i].language_id;
 936    if (!id->is_language_deleted) {
 937      result++;
 938    }
 939  }
 940  return result;
 941}
 942
 943static bool ts_wasm_store__instantiate(
 944  TSWasmStore *self,
 945  wasmtime_module_t *module,
 946  const char *language_name,
 947  const WasmDylinkInfo *dylink_info,
 948  wasmtime_instance_t *result,
 949  int32_t *language_address,
 950  char **error_message
 951) {
 952  wasmtime_error_t *error = NULL;
 953  wasm_trap_t *trap = NULL;
 954  wasm_message_t message = WASM_EMPTY_VEC;
 955  char *language_function_name = NULL;
 956  wasmtime_extern_t *imports = NULL;
 957  wasmtime_context_t *context = wasmtime_store_context(self->store);
 958
 959  // Grow the function table to make room for the new functions.
 960  wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF};
 961  uint32_t prev_table_size;
 962  error = wasmtime_table_grow(context, &self->function_table, dylink_info->table_size, &initializer, &prev_table_size);
 963  if (error) {
 964    format(error_message, "invalid function table size %u", dylink_info->table_size);
 965    goto error;
 966  }
 967
 968  // Grow the memory to make room for the new data.
 969  uint32_t needed_memory_size = self->current_memory_offset + dylink_info->memory_size;
 970  uint32_t current_memory_size = wasmtime_memory_data_size(context, &self->memory);
 971  if (needed_memory_size > current_memory_size) {
 972    uint32_t pages_to_grow = (
 973      needed_memory_size - current_memory_size + MEMORY_PAGE_SIZE - 1) /
 974      MEMORY_PAGE_SIZE;
 975    uint64_t prev_memory_size;
 976    error = wasmtime_memory_grow(context, &self->memory, pages_to_grow, &prev_memory_size);
 977    if (error) {
 978      format(error_message, "invalid memory size %u", dylink_info->memory_size);
 979      goto error;
 980    }
 981  }
 982
 983  // Construct the language function name as string.
 984  format(&language_function_name, "tree_sitter_%s", language_name);
 985
 986  const uint64_t store_id = self->function_table.store_id;
 987
 988  // Build the imports list for the module.
 989  wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
 990  wasmtime_module_imports(module, &import_types);
 991  imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t));
 992
 993  for (unsigned i = 0; i < import_types.size; i++) {
 994    const wasm_importtype_t *import_type = import_types.data[i];
 995    const wasm_name_t *import_name = wasm_importtype_name(import_type);
 996    if (import_name->size == 0) {
 997      format(error_message, "empty import name");
 998      goto error;
 999    }
1000
1001    if (ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) {
1002      continue;
1003    }
1004
1005    bool defined_in_stdlib = false;
1006    for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) {
1007      if (name_eq(import_name, STDLIB_SYMBOLS[j])) {
1008        uint16_t address = self->stdlib_fn_indices[j];
1009        imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = {store_id, address}};
1010        defined_in_stdlib = true;
1011        break;
1012      }
1013    }
1014
1015    if (!defined_in_stdlib) {
1016      format(
1017        error_message,
1018        "invalid import '%.*s'\n",
1019        (int)import_name->size, import_name->data
1020      );
1021      goto error;
1022    }
1023  }
1024
1025  wasmtime_instance_t instance;
1026  error = wasmtime_instance_new(context, module, imports, import_types.size, &instance, &trap);
1027  wasm_importtype_vec_delete(&import_types);
1028  ts_free(imports);
1029  imports = NULL;
1030  if (error) {
1031    wasmtime_error_message(error, &message);
1032    format(
1033      error_message,
1034      "error instantiating wasm module: %.*s\n",
1035      (int)message.size, message.data
1036    );
1037    goto error;
1038  }
1039  if (trap) {
1040    wasm_trap_message(trap, &message);
1041    format(
1042      error_message,
1043      "trap when instantiating wasm module: %.*s\n",
1044      (int)message.size, message.data
1045    );
1046    goto error;
1047  }
1048
1049  self->current_memory_offset += dylink_info->memory_size;
1050  self->current_function_table_offset += dylink_info->table_size;
1051
1052  // Process the module's exports.
1053  bool found_language = false;
1054  wasmtime_extern_t language_extern;
1055  wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC;
1056  wasmtime_module_exports(module, &export_types);
1057  for (unsigned i = 0; i < export_types.size; i++) {
1058    wasm_exporttype_t *export_type = export_types.data[i];
1059    const wasm_name_t *name = wasm_exporttype_name(export_type);
1060
1061    size_t name_len;
1062    char *export_name;
1063    wasmtime_extern_t export = {.kind = WASM_EXTERN_GLOBAL};
1064    bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export);
1065    assert(exists);
1066
1067    // If the module exports an initialization or data-relocation function, call it.
1068    if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) {
1069      if (trap) {
1070        wasm_trap_message(trap, &message);
1071        format(
1072          error_message,
1073          "trap when calling data relocation function: %.*s\n",
1074          (int)message.size, message.data
1075        );
1076        goto error;
1077      }
1078    }
1079
1080    // Find the main language function for the module.
1081    else if (name_eq(name, language_function_name)) {
1082      language_extern = export;
1083      found_language = true;
1084    }
1085  }
1086  wasm_exporttype_vec_delete(&export_types);
1087
1088  if (!found_language) {
1089    format(
1090      error_message,
1091      "module did not contain language function: %s",
1092      language_function_name
1093    );
1094    goto error;
1095  }
1096
1097  // Invoke the language function to get the static address of the language object.
1098  wasmtime_func_t language_func = language_extern.of.func;
1099  wasmtime_val_t language_address_val;
1100  error = wasmtime_func_call(context, &language_func, NULL, 0, &language_address_val, 1, &trap);
1101  assert(!error);
1102  if (trap) {
1103    wasm_trap_message(trap, &message);
1104    format(
1105      error_message,
1106      "trapped when calling language function: %s: %.*s\n",
1107      language_function_name, (int)message.size, message.data
1108    );
1109    goto error;
1110  }
1111
1112  if (language_address_val.kind != WASMTIME_I32) {
1113    format(
1114      error_message,
1115      "language function did not return an integer: %s\n",
1116      language_function_name
1117    );
1118    goto error;
1119  }
1120
1121  ts_free(language_function_name);
1122  *result = instance;
1123  *language_address = language_address_val.of.i32;
1124  return true;
1125
1126error:
1127  if (language_function_name) ts_free(language_function_name);
1128  if (message.size) wasm_byte_vec_delete(&message);
1129  if (error) wasmtime_error_delete(error);
1130  if (trap) wasm_trap_delete(trap);
1131  if (imports) ts_free(imports);
1132  return false;
1133}
1134
1135static bool ts_wasm_store__sentinel_lex_fn(TSLexer *_lexer, TSStateId state) {
1136  return false;
1137}
1138
1139const TSLanguage *ts_wasm_store_load_language(
1140  TSWasmStore *self,
1141  const char *language_name,
1142  const char *wasm,
1143  uint32_t wasm_len,
1144  TSWasmError *wasm_error
1145) {
1146  WasmDylinkInfo dylink_info;
1147  wasmtime_module_t *module = NULL;
1148  wasmtime_error_t *error = NULL;
1149  wasm_error->kind = TSWasmErrorKindNone;
1150
1151  if (!wasm_dylink_info__parse((const unsigned char *)wasm, wasm_len, &dylink_info)) {
1152    wasm_error->kind = TSWasmErrorKindParse;
1153    format(&wasm_error->message, "failed to parse dylink section of wasm module");
1154    goto error;
1155  }
1156
1157  // Compile the wasm code.
1158  error = wasmtime_module_new(self->engine, (const uint8_t *)wasm, wasm_len, &module);
1159  if (error) {
1160    wasm_message_t message;
1161    wasmtime_error_message(error, &message);
1162    wasm_error->kind = TSWasmErrorKindCompile;
1163    format(&wasm_error->message, "error compiling wasm module: %.*s", (int)message.size, message.data);
1164    wasm_byte_vec_delete(&message);
1165    goto error;
1166  }
1167
1168  // Instantiate the module in this store.
1169  wasmtime_instance_t instance;
1170  int32_t language_address;
1171  if (!ts_wasm_store__instantiate(
1172    self,
1173    module,
1174    language_name,
1175    &dylink_info,
1176    &instance,
1177    &language_address,
1178    &wasm_error->message
1179  )) {
1180    wasm_error->kind = TSWasmErrorKindInstantiate;
1181    goto error;
1182  }
1183
1184  // Copy all of the static data out of the language object in wasm memory,
1185  // constructing a native language object.
1186  LanguageInWasmMemory wasm_language;
1187  wasmtime_context_t *context = wasmtime_store_context(self->store);
1188  const uint8_t *memory = wasmtime_memory_data(context, &self->memory);
1189  memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory));
1190
1191  if (wasm_language.version < LANGUAGE_VERSION_USABLE_VIA_WASM) {
1192      wasm_error->kind = TSWasmErrorKindInstantiate;
1193      format(&wasm_error->message, "language version %u is too old for wasm", wasm_language.version);
1194      goto error;
1195  }
1196
1197  int32_t addresses[] = {
1198    wasm_language.alias_map,
1199    wasm_language.alias_sequences,
1200    wasm_language.field_map_entries,
1201    wasm_language.field_map_slices,
1202    wasm_language.field_names,
1203    wasm_language.keyword_lex_fn,
1204    wasm_language.lex_fn,
1205    wasm_language.lex_modes,
1206    wasm_language.parse_actions,
1207    wasm_language.parse_table,
1208    wasm_language.primary_state_ids,
1209    wasm_language.primary_state_ids,
1210    wasm_language.public_symbol_map,
1211    wasm_language.small_parse_table,
1212    wasm_language.small_parse_table_map,
1213    wasm_language.symbol_metadata,
1214    wasm_language.symbol_metadata,
1215    wasm_language.symbol_names,
1216    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.states : 0,
1217    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.symbol_map : 0,
1218    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.create : 0,
1219    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.destroy : 0,
1220    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.scan : 0,
1221    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.serialize : 0,
1222    wasm_language.external_token_count > 0 ? wasm_language.external_scanner.deserialize : 0,
1223    language_address,
1224    self->current_memory_offset,
1225  };
1226  uint32_t address_count = array_len(addresses);
1227
1228  TSLanguage *language = ts_calloc(1, sizeof(TSLanguage));
1229  StringData symbol_name_buffer = array_new();
1230  StringData field_name_buffer = array_new();
1231
1232  *language = (TSLanguage) {
1233    .version = wasm_language.version,
1234    .symbol_count = wasm_language.symbol_count,
1235    .alias_count = wasm_language.alias_count,
1236    .token_count = wasm_language.token_count,
1237    .external_token_count = wasm_language.external_token_count,
1238    .state_count = wasm_language.state_count,
1239    .large_state_count = wasm_language.large_state_count,
1240    .production_id_count = wasm_language.production_id_count,
1241    .field_count = wasm_language.field_count,
1242    .max_alias_sequence_length = wasm_language.max_alias_sequence_length,
1243    .keyword_capture_token = wasm_language.keyword_capture_token,
1244    .parse_table = copy(
1245      &memory[wasm_language.parse_table],
1246      wasm_language.large_state_count * wasm_language.symbol_count * sizeof(uint16_t)
1247    ),
1248    .parse_actions = copy_unsized_static_array(
1249      memory,
1250      wasm_language.parse_actions,
1251      addresses,
1252      address_count
1253    ),
1254    .symbol_names = copy_strings(
1255      memory,
1256      wasm_language.symbol_names,
1257      wasm_language.symbol_count + wasm_language.alias_count,
1258      &symbol_name_buffer
1259    ),
1260    .symbol_metadata = copy(
1261      &memory[wasm_language.symbol_metadata],
1262      (wasm_language.symbol_count + wasm_language.alias_count) * sizeof(TSSymbolMetadata)
1263    ),
1264    .public_symbol_map = copy(
1265      &memory[wasm_language.public_symbol_map],
1266      (wasm_language.symbol_count + wasm_language.alias_count) * sizeof(TSSymbol)
1267    ),
1268    .lex_modes = copy(
1269      &memory[wasm_language.lex_modes],
1270      wasm_language.state_count * sizeof(TSLexMode)
1271    ),
1272  };
1273
1274  if (language->field_count > 0 && language->production_id_count > 0) {
1275    language->field_map_slices = copy(
1276      &memory[wasm_language.field_map_slices],
1277      wasm_language.production_id_count * sizeof(TSFieldMapSlice)
1278    );
1279    const TSFieldMapSlice last_field_map_slice = language->field_map_slices[language->production_id_count - 1];
1280    language->field_map_entries = copy(
1281      &memory[wasm_language.field_map_entries],
1282      (last_field_map_slice.index + last_field_map_slice.length) * sizeof(TSFieldMapEntry)
1283    );
1284    language->field_names = copy_strings(
1285      memory,
1286      wasm_language.field_names,
1287      wasm_language.field_count + 1,
1288      &field_name_buffer
1289    );
1290  }
1291
1292  if (language->max_alias_sequence_length > 0 && language->production_id_count > 0) {
1293    // The alias map contains symbols, alias counts, and aliases, terminated by a null symbol.
1294    int32_t alias_map_size = 0;
1295    for (;;) {
1296      TSSymbol symbol;
1297      memcpy(&symbol, &memory[wasm_language.alias_map + alias_map_size], sizeof(symbol));
1298      alias_map_size += sizeof(TSSymbol);
1299      if (symbol == 0) break;
1300      uint16_t value_count;
1301      memcpy(&value_count, &memory[wasm_language.alias_map + alias_map_size], sizeof(value_count));
1302      alias_map_size += value_count * sizeof(TSSymbol);
1303    }
1304    language->alias_map = copy(
1305      &memory[wasm_language.alias_map],
1306      alias_map_size * sizeof(TSSymbol)
1307    );
1308    language->alias_sequences = copy(
1309      &memory[wasm_language.alias_sequences],
1310      wasm_language.production_id_count * wasm_language.max_alias_sequence_length * sizeof(TSSymbol)
1311    );
1312  }
1313
1314  if (language->state_count > language->large_state_count) {
1315    uint32_t small_state_count = wasm_language.state_count - wasm_language.large_state_count;
1316    language->small_parse_table_map = copy(
1317      &memory[wasm_language.small_parse_table_map],
1318      small_state_count * sizeof(uint32_t)
1319    );
1320    language->small_parse_table = copy_unsized_static_array(
1321      memory,
1322      wasm_language.small_parse_table,
1323      addresses,
1324      address_count
1325    );
1326  }
1327
1328  if (language->version >= LANGUAGE_VERSION_WITH_PRIMARY_STATES) {
1329    language->primary_state_ids = copy(
1330      &memory[wasm_language.primary_state_ids],
1331      wasm_language.state_count * sizeof(TSStateId)
1332    );
1333  }
1334
1335  if (language->external_token_count > 0) {
1336    language->external_scanner.symbol_map = copy(
1337      &memory[wasm_language.external_scanner.symbol_map],
1338      wasm_language.external_token_count * sizeof(TSSymbol)
1339    );
1340    language->external_scanner.states = (void *)(uintptr_t)wasm_language.external_scanner.states;
1341  }
1342
1343  unsigned name_len = strlen(language_name);
1344  char *name = ts_malloc(name_len + 1);
1345  memcpy(name, language_name, name_len);
1346  name[name_len] = '\0';
1347
1348  LanguageWasmModule *language_module = ts_malloc(sizeof(LanguageWasmModule));
1349  *language_module = (LanguageWasmModule) {
1350    .language_id = language_id_new(),
1351    .module = module,
1352    .name = name,
1353    .symbol_name_buffer = symbol_name_buffer.contents,
1354    .field_name_buffer = field_name_buffer.contents,
1355    .dylink_info = dylink_info,
1356    .ref_count = 1,
1357  };
1358
1359  // The lex functions are not used for wasm languages. Use those two fields
1360  // to mark this language as WASM-based and to store the language's
1361  // WASM-specific data.
1362  language->lex_fn = ts_wasm_store__sentinel_lex_fn;
1363  language->keyword_lex_fn = (void *)language_module;
1364
1365  // Clear out any instances of languages that have been deleted.
1366  for (unsigned i = 0; i < self->language_instances.size; i++) {
1367    WasmLanguageId *id = self->language_instances.contents[i].language_id;
1368    if (id->is_language_deleted) {
1369      language_id_delete(id);
1370      array_erase(&self->language_instances, i);
1371      i--;
1372    }
1373  }
1374
1375  // Store this store's instance of this language module.
1376  array_push(&self->language_instances, ((LanguageWasmInstance) {
1377    .language_id = language_id_clone(language_module->language_id),
1378    .instance = instance,
1379    .external_states_address = wasm_language.external_scanner.states,
1380    .lex_main_fn_index = wasm_language.lex_fn,
1381    .lex_keyword_fn_index = wasm_language.keyword_lex_fn,
1382    .scanner_create_fn_index = wasm_language.external_scanner.create,
1383    .scanner_destroy_fn_index = wasm_language.external_scanner.destroy,
1384    .scanner_serialize_fn_index = wasm_language.external_scanner.serialize,
1385    .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize,
1386    .scanner_scan_fn_index = wasm_language.external_scanner.scan,
1387  }));
1388
1389  return language;
1390
1391error:
1392  if (module) wasmtime_module_delete(module);
1393  return NULL;
1394}
1395
1396bool ts_wasm_store_add_language(
1397  TSWasmStore *self,
1398  const TSLanguage *language,
1399  uint32_t *index
1400) {
1401  wasmtime_context_t *context = wasmtime_store_context(self->store);
1402  const LanguageWasmModule *language_module = (void *)language->keyword_lex_fn;
1403
1404  // Search for this store's instance of the language module. Also clear out any
1405  // instances of languages that have been deleted.
1406  bool exists = false;
1407  for (unsigned i = 0; i < self->language_instances.size; i++) {
1408    WasmLanguageId *id = self->language_instances.contents[i].language_id;
1409    if (id->is_language_deleted) {
1410      language_id_delete(id);
1411      array_erase(&self->language_instances, i);
1412      i--;
1413    } else if (id == language_module->language_id) {
1414      exists = true;
1415      *index = i;
1416    }
1417  }
1418
1419  // If the language module has not been instantiated in this store, then add
1420  // it to this store.
1421  if (!exists) {
1422    *index = self->language_instances.size;
1423    char *message;
1424    wasmtime_instance_t instance;
1425    int32_t language_address;
1426    if (!ts_wasm_store__instantiate(
1427      self,
1428      language_module->module,
1429      language_module->name,
1430      &language_module->dylink_info,
1431      &instance,
1432      &language_address,
1433      &message
1434    )) {
1435      ts_free(message);
1436      return false;
1437    }
1438
1439    LanguageInWasmMemory wasm_language;
1440    const uint8_t *memory = wasmtime_memory_data(context, &self->memory);
1441    memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory));
1442    array_push(&self->language_instances, ((LanguageWasmInstance) {
1443      .language_id = language_id_clone(language_module->language_id),
1444      .instance = instance,
1445      .external_states_address = wasm_language.external_scanner.states,
1446      .lex_main_fn_index = wasm_language.lex_fn,
1447      .lex_keyword_fn_index = wasm_language.keyword_lex_fn,
1448      .scanner_create_fn_index = wasm_language.external_scanner.create,
1449      .scanner_destroy_fn_index = wasm_language.external_scanner.destroy,
1450      .scanner_serialize_fn_index = wasm_language.external_scanner.serialize,
1451      .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize,
1452      .scanner_scan_fn_index = wasm_language.external_scanner.scan,
1453    }));
1454  }
1455
1456  return true;
1457}
1458
1459void ts_wasm_store_reset_heap(TSWasmStore *self) {
1460  wasmtime_context_t *context = wasmtime_store_context(self->store);
1461  wasmtime_func_t func = {
1462    self->function_table.store_id,
1463    self->builtin_fn_indices.reset_heap
1464  };
1465  wasm_trap_t *trap = NULL;
1466  wasmtime_val_t args[1] = {
1467    {.of.i32 = self->current_memory_offset, .kind = WASMTIME_I32},
1468  };
1469
1470  wasmtime_error_t *error = wasmtime_func_call(context, &func, args, 1, NULL, 0, &trap);
1471  assert(!error);
1472  assert(!trap);
1473}
1474
1475bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) {
1476  uint32_t instance_index;
1477  if (!ts_wasm_store_add_language(self, language, &instance_index)) return false;
1478  self->current_lexer = lexer;
1479  self->current_instance = &self->language_instances.contents[instance_index];
1480  self->has_error = false;
1481  ts_wasm_store_reset_heap(self);
1482  return true;
1483}
1484
1485void ts_wasm_store_reset(TSWasmStore *self) {
1486  self->current_lexer = NULL;
1487  self->current_instance = NULL;
1488  self->has_error = false;
1489  ts_wasm_store_reset_heap(self);
1490}
1491
1492static void ts_wasm_store__call(
1493  TSWasmStore *self,
1494  int32_t function_index,
1495  wasmtime_val_raw_t *args_and_results,
1496  size_t args_and_results_len
1497) {
1498  wasmtime_context_t *context = wasmtime_store_context(self->store);
1499  wasmtime_val_t value;
1500  bool succeeded = wasmtime_table_get(context, &self->function_table, function_index, &value);
1501  assert(succeeded);
1502  assert(value.kind == WASMTIME_FUNCREF);
1503  wasmtime_func_t func = value.of.funcref;
1504
1505  wasm_trap_t *trap = NULL;
1506  wasmtime_error_t *error = wasmtime_func_call_unchecked(context, &func, args_and_results, args_and_results_len, &trap);
1507  if (error) {
1508    // wasm_message_t message;
1509    // wasmtime_error_message(error, &message);
1510    // fprintf(
1511    //   stderr,
1512    //   "error in wasm module: %.*s\n",
1513    //   (int)message.size, message.data
1514    // );
1515    wasmtime_error_delete(error);
1516    self->has_error = true;
1517  } else if (trap) {
1518    // wasm_message_t message;
1519    // wasm_trap_message(trap, &message);
1520    // fprintf(
1521    //   stderr,
1522    //   "trap in wasm module: %.*s\n",
1523    //   (int)message.size, message.data
1524    // );
1525    wasm_trap_delete(trap);
1526    self->has_error = true;
1527  }
1528}
1529
1530static bool ts_wasm_store__call_lex_function(TSWasmStore *self, unsigned function_index, TSStateId state) {
1531  wasmtime_context_t *context = wasmtime_store_context(self->store);
1532  uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
1533  memcpy(
1534    &memory_data[self->lexer_address],
1535    &self->current_lexer->lookahead,
1536    sizeof(self->current_lexer->lookahead)
1537  );
1538
1539  wasmtime_val_raw_t args[2] = {
1540    {.i32 = self->lexer_address},
1541    {.i32 = state},
1542  };
1543  ts_wasm_store__call(self, function_index, args, 2);
1544  if (self->has_error) return false;
1545  bool result = args[0].i32;
1546
1547  memcpy(
1548    &self->current_lexer->lookahead,
1549    &memory_data[self->lexer_address],
1550    sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
1551  );
1552  return result;
1553}
1554
1555bool ts_wasm_store_call_lex_main(TSWasmStore *self, TSStateId state) {
1556  return ts_wasm_store__call_lex_function(
1557    self,
1558    self->current_instance->lex_main_fn_index,
1559    state
1560  );
1561}
1562
1563bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) {
1564  return ts_wasm_store__call_lex_function(
1565    self,
1566    self->current_instance->lex_keyword_fn_index,
1567    state
1568  );
1569}
1570
1571uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) {
1572  wasmtime_val_raw_t args[1] = {{.i32 = 0}};
1573  ts_wasm_store__call(self, self->current_instance->scanner_create_fn_index, args, 1);
1574  if (self->has_error) return 0;
1575  return args[0].i32;
1576}
1577
1578void ts_wasm_store_call_scanner_destroy(TSWasmStore *self, uint32_t scanner_address) {
1579  if (self->current_instance) {
1580    wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}};
1581    ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1);
1582  }
1583}
1584
1585bool ts_wasm_store_call_scanner_scan(
1586  TSWasmStore *self,
1587  uint32_t scanner_address,
1588  uint32_t valid_tokens_ix
1589) {
1590  wasmtime_context_t *context = wasmtime_store_context(self->store);
1591  uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
1592
1593  memcpy(
1594    &memory_data[self->lexer_address],
1595    &self->current_lexer->lookahead,
1596    sizeof(self->current_lexer->lookahead)
1597  );
1598
1599  uint32_t valid_tokens_address =
1600    self->current_instance->external_states_address +
1601    (valid_tokens_ix * sizeof(bool));
1602  wasmtime_val_raw_t args[3] = {
1603    {.i32 = scanner_address},
1604    {.i32 = self->lexer_address},
1605    {.i32 = valid_tokens_address}
1606  };
1607  ts_wasm_store__call(self, self->current_instance->scanner_scan_fn_index, args, 3);
1608  if (self->has_error) return false;
1609
1610  memcpy(
1611    &self->current_lexer->lookahead,
1612    &memory_data[self->lexer_address],
1613    sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
1614  );
1615  return args[0].i32;
1616}
1617
1618uint32_t ts_wasm_store_call_scanner_serialize(
1619  TSWasmStore *self,
1620  uint32_t scanner_address,
1621  char *buffer
1622) {
1623  wasmtime_context_t *context = wasmtime_store_context(self->store);
1624  uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
1625
1626  wasmtime_val_raw_t args[2] = {
1627    {.i32 = scanner_address},
1628    {.i32 = self->serialization_buffer_address},
1629  };
1630  ts_wasm_store__call(self, self->current_instance->scanner_serialize_fn_index, args, 2);
1631  if (self->has_error) return 0;
1632
1633  uint32_t length = args[0].i32;
1634
1635  if (length > 0) {
1636    memcpy(
1637      ((Lexer *)self->current_lexer)->debug_buffer,
1638      &memory_data[self->serialization_buffer_address],
1639      length
1640    );
1641  }
1642  return length;
1643}
1644
1645void ts_wasm_store_call_scanner_deserialize(
1646  TSWasmStore *self,
1647  uint32_t scanner_address,
1648  const char *buffer,
1649  unsigned length
1650) {
1651  wasmtime_context_t *context = wasmtime_store_context(self->store);
1652  uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
1653
1654  if (length > 0) {
1655    memcpy(
1656      &memory_data[self->serialization_buffer_address],
1657      buffer,
1658      length
1659    );
1660  }
1661
1662  wasmtime_val_raw_t args[3] = {
1663    {.i32 = scanner_address},
1664    {.i32 = self->serialization_buffer_address},
1665    {.i32 = length},
1666  };
1667  ts_wasm_store__call(self, self->current_instance->scanner_deserialize_fn_index, args, 3);
1668}
1669
1670bool ts_wasm_store_has_error(const TSWasmStore *self) {
1671  return self->has_error;
1672}
1673
1674bool ts_language_is_wasm(const TSLanguage *self) {
1675  return self->lex_fn == ts_wasm_store__sentinel_lex_fn;
1676}
1677
1678static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) {
1679  return (LanguageWasmModule *)self->keyword_lex_fn;
1680}
1681
1682void ts_wasm_language_retain(const TSLanguage *self) {
1683  LanguageWasmModule *module = ts_language__wasm_module(self);
1684  assert(module->ref_count > 0);
1685  atomic_inc(&module->ref_count);
1686}
1687
1688void ts_wasm_language_release(const TSLanguage *self) {
1689  LanguageWasmModule *module = ts_language__wasm_module(self);
1690  assert(module->ref_count > 0);
1691  if (atomic_dec(&module->ref_count) == 0) {
1692    // Update the language id to reflect that the language is deleted. This allows any wasm stores
1693    // that hold wasm instances for this language to delete those instances.
1694    atomic_inc(&module->language_id->is_language_deleted);
1695    language_id_delete(module->language_id);
1696
1697    ts_free((void *)module->field_name_buffer);
1698    ts_free((void *)module->symbol_name_buffer);
1699    ts_free((void *)module->name);
1700    wasmtime_module_delete(module->module);
1701    ts_free(module);
1702
1703    ts_free((void *)self->alias_map);
1704    ts_free((void *)self->alias_sequences);
1705    ts_free((void *)self->external_scanner.symbol_map);
1706    ts_free((void *)self->field_map_entries);
1707    ts_free((void *)self->field_map_slices);
1708    ts_free((void *)self->field_names);
1709    ts_free((void *)self->lex_modes);
1710    ts_free((void *)self->parse_actions);
1711    ts_free((void *)self->parse_table);
1712    ts_free((void *)self->primary_state_ids);
1713    ts_free((void *)self->public_symbol_map);
1714    ts_free((void *)self->small_parse_table);
1715    ts_free((void *)self->small_parse_table_map);
1716    ts_free((void *)self->symbol_metadata);
1717    ts_free((void *)self->symbol_names);
1718    ts_free((void *)self);
1719  }
1720}
1721
1722#else
1723
1724// If the WASM feature is not enabled, define dummy versions of all of the
1725// wasm-related functions.
1726
1727void ts_wasm_store_delete(TSWasmStore *self) {
1728  (void)self;
1729}
1730
1731bool ts_wasm_store_start(
1732  TSWasmStore *self,
1733  TSLexer *lexer,
1734  const TSLanguage *language
1735) {
1736  (void)self;
1737  (void)lexer;
1738  (void)language;
1739  return false;
1740}
1741
1742void ts_wasm_store_reset(TSWasmStore *self) {
1743  (void)self;
1744}
1745
1746bool ts_wasm_store_call_lex_main(TSWasmStore *self, TSStateId state) {
1747  (void)self;
1748  (void)state;
1749  return false;
1750}
1751
1752bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) {
1753  (void)self;
1754  (void)state;
1755  return false;
1756}
1757
1758uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) {
1759  (void)self;
1760  return 0;
1761}
1762
1763void ts_wasm_store_call_scanner_destroy(
1764  TSWasmStore *self,
1765  uint32_t scanner_address
1766) {
1767  (void)self;
1768  (void)scanner_address;
1769}
1770
1771bool ts_wasm_store_call_scanner_scan(
1772  TSWasmStore *self,
1773  uint32_t scanner_address,
1774  uint32_t valid_tokens_ix
1775) {
1776  (void)self;
1777  (void)scanner_address;
1778  (void)valid_tokens_ix;
1779  return false;
1780}
1781
1782uint32_t ts_wasm_store_call_scanner_serialize(
1783  TSWasmStore *self,
1784  uint32_t scanner_address,
1785  char *buffer
1786) {
1787  (void)self;
1788  (void)scanner_address;
1789  (void)buffer;
1790  return 0;
1791}
1792
1793void ts_wasm_store_call_scanner_deserialize(
1794  TSWasmStore *self,
1795  uint32_t scanner_address,
1796  const char *buffer,
1797  unsigned length
1798) {
1799  (void)self;
1800  (void)scanner_address;
1801  (void)buffer;
1802  (void)length;
1803}
1804
1805bool ts_wasm_store_has_error(const TSWasmStore *self) {
1806  (void)self;
1807  return false;
1808}
1809
1810bool ts_language_is_wasm(const TSLanguage *self) {
1811  (void)self;
1812  return false;
1813}
1814
1815void ts_wasm_language_retain(const TSLanguage *self) {
1816  (void)self;
1817}
1818
1819void ts_wasm_language_release(const TSLanguage *self) {
1820  (void)self;
1821}
1822
1823#endif