diff --git a/CMakeLists.txt b/CMakeLists.txt index fa4305d9e..e2af188d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -238,6 +238,14 @@ if(MLX_ENGINE) FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX ) + if(TARGET jaccl) + install(TARGETS jaccl + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + ) + endif() + # Install the Metal library for macOS arm64 (must be colocated with the binary) # Metal backend is only built for arm64, not x86_64 if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") diff --git a/MLX_C_VERSION b/MLX_C_VERSION index 5dc11c479..b3511376b 100644 --- a/MLX_C_VERSION +++ b/MLX_C_VERSION @@ -1 +1 @@ -0726ca922fc902c4c61ef9c27d94132be418e945 +fba4470b89073180056c9ea46c443051375f7399 diff --git a/MLX_VERSION b/MLX_VERSION index 12d7c829a..02237b6f6 100644 --- a/MLX_VERSION +++ b/MLX_VERSION @@ -1 +1 @@ -38ad257088fb2193ad47e527cf6534a689f30943 +e8ebdebeeb655feaa85a51f6b24ece5b6d5518d1 diff --git a/llm/status.go b/llm/status.go index c2dd6ac26..ed611e9ca 100644 --- a/llm/status.go +++ b/llm/status.go @@ -65,7 +65,12 @@ func (w *StatusWriter) AppendError(msg string) { // logs, add a small rolling buffer here to capture those fragments. var errorPrefixes = []string{ + "mlx:", + "MLX:", + "panic:", + "fatal error:", "error:", + "Error:", "CUDA error", "ROCm error", "cudaMalloc failed", @@ -79,15 +84,21 @@ var errorPrefixes = []string{ func (w *StatusWriter) Write(b []byte) (int, error) { var errMsg string + errStart := -1 + var errPrefix string for _, prefix := range errorPrefixes { - if _, after, ok := bytes.Cut(b, []byte(prefix)); ok { - line := after - if j := bytes.IndexByte(line, '\n'); j >= 0 { - line = line[:j] - } - errMsg = prefix + string(bytes.TrimRight(line, " \t\r")) + if i := bytes.Index(b, []byte(prefix)); i >= 0 && (errStart < 0 || i < errStart) { + errStart = i + errPrefix = prefix } } + if errStart >= 0 { + line := b[errStart+len(errPrefix):] + if j := bytes.IndexByte(line, '\n'); j >= 0 { + line = line[:j] + } + errMsg = errPrefix + string(bytes.TrimRight(line, " \t\r")) + } if errMsg != "" { w.AppendError(errMsg) } diff --git a/llm/status_test.go b/llm/status_test.go index 2297ddd25..200d9f262 100644 --- a/llm/status_test.go +++ b/llm/status_test.go @@ -1,35 +1,59 @@ package llm import ( - "os" + "io" "testing" ) func TestStatusWriterCapturesErrorLine(t *testing.T) { - f, err := os.CreateTemp(t.TempDir(), "status-writer") - if err != nil { - t.Fatal(err) - } - defer f.Close() - - w := NewStatusWriter(f) - if _, err := w.Write([]byte("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend\n")); err != nil { - t.Fatal(err) + tests := []struct { + name string + log string + want string + }{ + { + name: "llama init", + log: "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend\n", + want: "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend", + }, + { + name: "cobra error", + log: "Error: foo baz bar\n", + want: "Error: foo baz bar", + }, + { + name: "uppercase mlx", + log: "MLX: there was an error\n", + want: "MLX: there was an error", + }, + { + name: "panic header", + log: "time=2026-05-01T15:36:45.053Z level=INFO source=pipeline.go:71 msg=\"peak memory\" size=\"8.26 GiB\"\n" + + "panic: mlx: Failed to compile kernel: nvrtc: error: invalid value for --gpu-architecture (-arch)\n" + + "\t. at /go/src/github.com/ollama/ollama/build/_deps/mlx-c-src/mlx/c/transforms.cpp:15\n\n" + + "goroutine 31 [running]:\n" + + "golang.org/x/sync/errgroup.(*Group).Go.func1()\n" + + "\tgolang.org/x/sync@v0.17.0/errgroup/errgroup.go:93 +0x50\n", + want: "panic: mlx: Failed to compile kernel: nvrtc: error: invalid value for --gpu-architecture (-arch)", + }, } - if got, want := w.LastError(), "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend"; got != want { - t.Fatalf("LastError = %q, want %q", got, want) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := NewStatusWriter(io.Discard) + if _, err := w.Write([]byte(tt.log)); err != nil { + t.Fatal(err) + } + + if got := w.LastError(); got != tt.want { + t.Fatalf("LastError = %q, want %q", got, tt.want) + } + }) } } func TestStatusWriterAccumulatesErrorLines(t *testing.T) { - f, err := os.CreateTemp(t.TempDir(), "status-writer") - if err != nil { - t.Fatal(err) - } - defer f.Close() - - w := NewStatusWriter(f) + w := NewStatusWriter(io.Discard) if _, err := w.Write([]byte("error: failed to initialize the Metal library\n")); err != nil { t.Fatal(err) } diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index f53e1c484..5962fad04 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -233,9 +233,11 @@ _build_macapp() { cp "$VARIANT$LIB" "$DEST/" fi done - # Copy remaining files (metallib) from arm64 v3 + # Copy remaining files (metallib and auxiliary runtime dylibs) + # from arm64 v3. libmlx/libmlxc are handled above so v3 can + # be universal when an x86_64 build is available. for F in "$VARIANT"*; do - case "$(basename "$F")" in *.dylib) continue ;; esac + case "$(basename "$F")" in libmlx.dylib|libmlxc.dylib) continue ;; esac [ -f "$F" ] && [ ! -L "$F" ] || continue cp "$F" "$DEST/" done diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c index 8d7ec0e0a..6352929cf 100644 --- a/x/imagegen/mlx/mlx.c +++ b/x/imagegen/mlx/mlx.c @@ -162,11 +162,13 @@ int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_n int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) = NULL; int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; +mlx_distributed_group (*mlx_distributed_group_new_ptr)(void) = NULL; +int (*mlx_distributed_group_free_ptr)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_init_ptr)(mlx_distributed_group* res, bool strict, const char* bk) = NULL; int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; -mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; +int (*mlx_distributed_group_split_ptr)(mlx_distributed_group* res, mlx_distributed_group group, int color, int key) = NULL; bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL; -mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL; void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL; int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; @@ -210,26 +212,35 @@ int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) = NULL; int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) = NULL; int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) = NULL; -int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; -int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_fftfreq_ptr)(mlx_array* res, int n, double d, const mlx_stream s) = NULL; +int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; -int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; -int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; -int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; -int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_rfftfreq_ptr)(mlx_array* res, int n, double d, const mlx_stream s) = NULL; +int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_node_namer_free_ptr)(mlx_node_namer namer) = NULL; +int (*mlx_node_namer_set_name_ptr)(mlx_node_namer namer, const mlx_array arr, const char* name) = NULL; +int (*mlx_node_namer_get_name_ptr)(const char** name, mlx_node_namer namer, const mlx_array arr) = NULL; +int (*mlx_export_to_dot_ptr)(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs) = NULL; +int (*mlx_print_graph_ptr)(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs) = NULL; int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL; int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s) = NULL; +int (*mlx_load_gguf_ptr)(mlx_io_gguf* gguf, const char* file, const mlx_stream s) = NULL; int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) = NULL; int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) = NULL; int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a) = NULL; int (*mlx_save_ptr)(const char* file, const mlx_array a) = NULL; +int (*mlx_save_gguf_ptr)(const char* file, mlx_io_gguf gguf) = NULL; int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; @@ -240,6 +251,20 @@ mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io) = NULL; int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io) = NULL; int (*mlx_io_writer_free_ptr)(mlx_io_writer io) = NULL; +mlx_io_gguf (*mlx_io_gguf_new_ptr)(void) = NULL; +int (*mlx_io_gguf_free_ptr)(mlx_io_gguf io) = NULL; +int (*mlx_io_gguf_get_keys_ptr)(mlx_vector_string* keys, mlx_io_gguf io) = NULL; +int (*mlx_io_gguf_get_array_ptr)(mlx_array* arr, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_array_ptr)(mlx_array* arr, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_string_ptr)(mlx_string* str, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_vector_string_ptr)(mlx_vector_string* vstr, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_array_ptr)(bool* flag, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_string_ptr)(bool* flag, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_vector_string_ptr)(bool* flag, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_set_array_ptr)(mlx_io_gguf io, const char* key, const mlx_array arr) = NULL; +int (*mlx_io_gguf_set_metadata_array_ptr)(mlx_io_gguf io, const char* key, const mlx_array marr) = NULL; +int (*mlx_io_gguf_set_metadata_string_ptr)(mlx_io_gguf io, const char* key, const char* mstr) = NULL; +int (*mlx_io_gguf_set_metadata_vector_string_ptr)(mlx_io_gguf io, const char* key, const mlx_vector_string mvstr) = NULL; int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL; @@ -474,6 +499,10 @@ int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) = NULL; int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_add_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_max_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_min_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_prod_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) = NULL; int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) = NULL; int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) = NULL; @@ -1320,6 +1349,21 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n"); return -1; } + mlx_distributed_group_new_ptr = GET_SYM(handle, "mlx_distributed_group_new"); + if (mlx_distributed_group_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_new\n"); + return -1; + } + mlx_distributed_group_free_ptr = GET_SYM(handle, "mlx_distributed_group_free"); + if (mlx_distributed_group_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_free\n"); + return -1; + } + mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init"); + if (mlx_distributed_init_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); + return -1; + } mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank"); if (mlx_distributed_group_rank_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n"); @@ -1340,11 +1384,6 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n"); return -1; } - mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init"); - if (mlx_distributed_init_ptr == NULL) { - fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); - return -1; - } mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler"); if (mlx_set_error_handler_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n"); @@ -1570,6 +1609,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n"); return -1; } + mlx_fft_fftfreq_ptr = GET_SYM(handle, "mlx_fft_fftfreq"); + if (mlx_fft_fftfreq_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftfreq\n"); + return -1; + } mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn"); if (mlx_fft_fftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n"); @@ -1625,11 +1669,41 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n"); return -1; } + mlx_fft_rfftfreq_ptr = GET_SYM(handle, "mlx_fft_rfftfreq"); + if (mlx_fft_rfftfreq_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftfreq\n"); + return -1; + } mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn"); if (mlx_fft_rfftn_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n"); return -1; } + mlx_node_namer_free_ptr = GET_SYM(handle, "mlx_node_namer_free"); + if (mlx_node_namer_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_node_namer_free\n"); + return -1; + } + mlx_node_namer_set_name_ptr = GET_SYM(handle, "mlx_node_namer_set_name"); + if (mlx_node_namer_set_name_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_node_namer_set_name\n"); + return -1; + } + mlx_node_namer_get_name_ptr = GET_SYM(handle, "mlx_node_namer_get_name"); + if (mlx_node_namer_get_name_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_node_namer_get_name\n"); + return -1; + } + mlx_export_to_dot_ptr = GET_SYM(handle, "mlx_export_to_dot"); + if (mlx_export_to_dot_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_export_to_dot\n"); + return -1; + } + mlx_print_graph_ptr = GET_SYM(handle, "mlx_print_graph"); + if (mlx_print_graph_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_print_graph\n"); + return -1; + } mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader"); if (mlx_load_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n"); @@ -1640,6 +1714,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n"); return -1; } + mlx_load_gguf_ptr = GET_SYM(handle, "mlx_load_gguf"); + if (mlx_load_gguf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_load_gguf\n"); + return -1; + } mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader"); if (mlx_load_safetensors_reader_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n"); @@ -1660,6 +1739,11 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n"); return -1; } + mlx_save_gguf_ptr = GET_SYM(handle, "mlx_save_gguf"); + if (mlx_save_gguf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_save_gguf\n"); + return -1; + } mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer"); if (mlx_save_safetensors_writer_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n"); @@ -1710,6 +1794,76 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n"); return -1; } + mlx_io_gguf_new_ptr = GET_SYM(handle, "mlx_io_gguf_new"); + if (mlx_io_gguf_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_new\n"); + return -1; + } + mlx_io_gguf_free_ptr = GET_SYM(handle, "mlx_io_gguf_free"); + if (mlx_io_gguf_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_free\n"); + return -1; + } + mlx_io_gguf_get_keys_ptr = GET_SYM(handle, "mlx_io_gguf_get_keys"); + if (mlx_io_gguf_get_keys_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_get_keys\n"); + return -1; + } + mlx_io_gguf_get_array_ptr = GET_SYM(handle, "mlx_io_gguf_get_array"); + if (mlx_io_gguf_get_array_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_get_array\n"); + return -1; + } + mlx_io_gguf_get_metadata_array_ptr = GET_SYM(handle, "mlx_io_gguf_get_metadata_array"); + if (mlx_io_gguf_get_metadata_array_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_get_metadata_array\n"); + return -1; + } + mlx_io_gguf_get_metadata_string_ptr = GET_SYM(handle, "mlx_io_gguf_get_metadata_string"); + if (mlx_io_gguf_get_metadata_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_get_metadata_string\n"); + return -1; + } + mlx_io_gguf_get_metadata_vector_string_ptr = GET_SYM(handle, "mlx_io_gguf_get_metadata_vector_string"); + if (mlx_io_gguf_get_metadata_vector_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_get_metadata_vector_string\n"); + return -1; + } + mlx_io_gguf_has_metadata_array_ptr = GET_SYM(handle, "mlx_io_gguf_has_metadata_array"); + if (mlx_io_gguf_has_metadata_array_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_has_metadata_array\n"); + return -1; + } + mlx_io_gguf_has_metadata_string_ptr = GET_SYM(handle, "mlx_io_gguf_has_metadata_string"); + if (mlx_io_gguf_has_metadata_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_has_metadata_string\n"); + return -1; + } + mlx_io_gguf_has_metadata_vector_string_ptr = GET_SYM(handle, "mlx_io_gguf_has_metadata_vector_string"); + if (mlx_io_gguf_has_metadata_vector_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_has_metadata_vector_string\n"); + return -1; + } + mlx_io_gguf_set_array_ptr = GET_SYM(handle, "mlx_io_gguf_set_array"); + if (mlx_io_gguf_set_array_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_set_array\n"); + return -1; + } + mlx_io_gguf_set_metadata_array_ptr = GET_SYM(handle, "mlx_io_gguf_set_metadata_array"); + if (mlx_io_gguf_set_metadata_array_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_set_metadata_array\n"); + return -1; + } + mlx_io_gguf_set_metadata_string_ptr = GET_SYM(handle, "mlx_io_gguf_set_metadata_string"); + if (mlx_io_gguf_set_metadata_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_set_metadata_string\n"); + return -1; + } + mlx_io_gguf_set_metadata_vector_string_ptr = GET_SYM(handle, "mlx_io_gguf_set_metadata_vector_string"); + if (mlx_io_gguf_set_metadata_vector_string_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_gguf_set_metadata_vector_string\n"); + return -1; + } mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky"); if (mlx_linalg_cholesky_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n"); @@ -2880,6 +3034,26 @@ int mlx_load_functions(void* handle) { fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n"); return -1; } + mlx_slice_update_add_ptr = GET_SYM(handle, "mlx_slice_update_add"); + if (mlx_slice_update_add_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_add\n"); + return -1; + } + mlx_slice_update_max_ptr = GET_SYM(handle, "mlx_slice_update_max"); + if (mlx_slice_update_max_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_max\n"); + return -1; + } + mlx_slice_update_min_ptr = GET_SYM(handle, "mlx_slice_update_min"); + if (mlx_slice_update_min_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_min\n"); + return -1; + } + mlx_slice_update_prod_ptr = GET_SYM(handle, "mlx_slice_update_prod"); + if (mlx_slice_update_prod_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_prod\n"); + return -1; + } mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes"); if (mlx_softmax_axes_ptr == NULL) { fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n"); @@ -4144,6 +4318,18 @@ int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_dis return mlx_distributed_sum_scatter_ptr(res, x, group, s); } +mlx_distributed_group mlx_distributed_group_new(void) { + return mlx_distributed_group_new_ptr(); +} + +int mlx_distributed_group_free(mlx_distributed_group group) { + return mlx_distributed_group_free_ptr(group); +} + +int mlx_distributed_init(mlx_distributed_group* res, bool strict, const char* bk) { + return mlx_distributed_init_ptr(res, strict, bk); +} + int mlx_distributed_group_rank(mlx_distributed_group group) { return mlx_distributed_group_rank_ptr(group); } @@ -4152,18 +4338,14 @@ int mlx_distributed_group_size(mlx_distributed_group group) { return mlx_distributed_group_size_ptr(group); } -mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { - return mlx_distributed_group_split_ptr(group, color, key); +int mlx_distributed_group_split(mlx_distributed_group* res, mlx_distributed_group group, int color, int key) { + return mlx_distributed_group_split_ptr(res, group, color, key); } bool mlx_distributed_is_available(const char* bk) { return mlx_distributed_is_available_ptr(bk); } -mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) { - return mlx_distributed_init_ptr(strict, bk); -} - void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { mlx_set_error_handler_ptr(handler, data, dtor); } @@ -4336,60 +4518,88 @@ int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array querie return mlx_fast_scaled_dot_product_attention_ptr(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); } -int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { - return mlx_fft_fft_ptr(res, a, n, axis, s); +int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_fft_ptr(res, a, n, axis, norm, s); } -int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, norm, s); } -int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s) { + return mlx_fft_fftfreq_ptr(res, n, d, s); +} + +int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, norm, s); } int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fftshift_ptr(res, a, axes, axes_num, s); } -int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { - return mlx_fft_ifft_ptr(res, a, n, axis, s); +int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_ifft_ptr(res, a, n, axis, norm, s); } -int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, norm, s); } -int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, norm, s); } int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifftshift_ptr(res, a, axes, axes_num, s); } -int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { - return mlx_fft_irfft_ptr(res, a, n, axis, s); +int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_irfft_ptr(res, a, n, axis, norm, s); } -int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, norm, s); } -int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, norm, s); } -int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { - return mlx_fft_rfft_ptr(res, a, n, axis, s); +int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_rfft_ptr(res, a, n, axis, norm, s); } -int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, norm, s); } -int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { - return mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, s); +int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s) { + return mlx_fft_rfftfreq_ptr(res, n, d, s); +} + +int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s) { + return mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, norm, s); +} + +int mlx_node_namer_free(mlx_node_namer namer) { + return mlx_node_namer_free_ptr(namer); +} + +int mlx_node_namer_set_name(mlx_node_namer namer, const mlx_array arr, const char* name) { + return mlx_node_namer_set_name_ptr(namer, arr, name); +} + +int mlx_node_namer_get_name(const char** name, mlx_node_namer namer, const mlx_array arr) { + return mlx_node_namer_get_name_ptr(name, namer, arr); +} + +int mlx_export_to_dot(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs) { + return mlx_export_to_dot_ptr(os, namer, outputs); +} + +int mlx_print_graph(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs) { + return mlx_print_graph_ptr(os, namer, outputs); } int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { @@ -4400,6 +4610,10 @@ int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_ptr(res, file, s); } +int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s) { + return mlx_load_gguf_ptr(gguf, file, s); +} + int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_safetensors_reader_ptr(res_0, res_1, in_stream, s); } @@ -4416,6 +4630,10 @@ int mlx_save(const char* file, const mlx_array a) { return mlx_save_ptr(file, a); } +int mlx_save_gguf(const char* file, mlx_io_gguf gguf) { + return mlx_save_gguf_ptr(file, gguf); +} + int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_writer_ptr(in_stream, param, metadata); } @@ -4456,6 +4674,62 @@ int mlx_io_writer_free(mlx_io_writer io) { return mlx_io_writer_free_ptr(io); } +mlx_io_gguf mlx_io_gguf_new(void) { + return mlx_io_gguf_new_ptr(); +} + +int mlx_io_gguf_free(mlx_io_gguf io) { + return mlx_io_gguf_free_ptr(io); +} + +int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io) { + return mlx_io_gguf_get_keys_ptr(keys, io); +} + +int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_get_array_ptr(arr, io, key); +} + +int mlx_io_gguf_get_metadata_array(mlx_array* arr, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_get_metadata_array_ptr(arr, io, key); +} + +int mlx_io_gguf_get_metadata_string(mlx_string* str, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_get_metadata_string_ptr(str, io, key); +} + +int mlx_io_gguf_get_metadata_vector_string(mlx_vector_string* vstr, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_get_metadata_vector_string_ptr(vstr, io, key); +} + +int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_has_metadata_array_ptr(flag, io, key); +} + +int mlx_io_gguf_has_metadata_string(bool* flag, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_has_metadata_string_ptr(flag, io, key); +} + +int mlx_io_gguf_has_metadata_vector_string(bool* flag, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_has_metadata_vector_string_ptr(flag, io, key); +} + +int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr) { + return mlx_io_gguf_set_array_ptr(io, key, arr); +} + +int mlx_io_gguf_set_metadata_array(mlx_io_gguf io, const char* key, const mlx_array marr) { + return mlx_io_gguf_set_metadata_array_ptr(io, key, marr); +} + +int mlx_io_gguf_set_metadata_string(mlx_io_gguf io, const char* key, const char* mstr) { + return mlx_io_gguf_set_metadata_string_ptr(io, key, mstr); +} + +int mlx_io_gguf_set_metadata_vector_string(mlx_io_gguf io, const char* key, const mlx_vector_string mvstr) { + return mlx_io_gguf_set_metadata_vector_string_ptr(io, key, mvstr); +} + int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_cholesky_ptr(res, a, upper, s); } @@ -5392,6 +5666,22 @@ int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_arra return mlx_slice_update_dynamic_ptr(res, src, update, start, axes, axes_num, s); } +int mlx_slice_update_add(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_update_add_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + +int mlx_slice_update_max(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_update_max_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + +int mlx_slice_update_min(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_update_min_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + +int mlx_slice_update_prod(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_update_prod_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) { return mlx_softmax_axes_ptr(res, a, axes, axes_num, precise, s); } diff --git a/x/imagegen/mlx/mlx.h b/x/imagegen/mlx/mlx.h index 3f53c8941..3f7e12a95 100644 --- a/x/imagegen/mlx/mlx.h +++ b/x/imagegen/mlx/mlx.h @@ -152,11 +152,13 @@ #undef mlx_distributed_recv_like #undef mlx_distributed_send #undef mlx_distributed_sum_scatter +#undef mlx_distributed_group_new +#undef mlx_distributed_group_free +#undef mlx_distributed_init #undef mlx_distributed_group_rank #undef mlx_distributed_group_size #undef mlx_distributed_group_split #undef mlx_distributed_is_available -#undef mlx_distributed_init #undef mlx_set_error_handler #undef _mlx_error #undef mlx_export_function @@ -202,6 +204,7 @@ #undef mlx_fast_scaled_dot_product_attention #undef mlx_fft_fft #undef mlx_fft_fft2 +#undef mlx_fft_fftfreq #undef mlx_fft_fftn #undef mlx_fft_fftshift #undef mlx_fft_ifft @@ -213,13 +216,21 @@ #undef mlx_fft_irfftn #undef mlx_fft_rfft #undef mlx_fft_rfft2 +#undef mlx_fft_rfftfreq #undef mlx_fft_rfftn +#undef mlx_node_namer_free +#undef mlx_node_namer_set_name +#undef mlx_node_namer_get_name +#undef mlx_export_to_dot +#undef mlx_print_graph #undef mlx_load_reader #undef mlx_load +#undef mlx_load_gguf #undef mlx_load_safetensors_reader #undef mlx_load_safetensors #undef mlx_save_writer #undef mlx_save +#undef mlx_save_gguf #undef mlx_save_safetensors_writer #undef mlx_save_safetensors #undef mlx_io_reader_new @@ -230,6 +241,20 @@ #undef mlx_io_writer_descriptor #undef mlx_io_writer_tostring #undef mlx_io_writer_free +#undef mlx_io_gguf_new +#undef mlx_io_gguf_free +#undef mlx_io_gguf_get_keys +#undef mlx_io_gguf_get_array +#undef mlx_io_gguf_get_metadata_array +#undef mlx_io_gguf_get_metadata_string +#undef mlx_io_gguf_get_metadata_vector_string +#undef mlx_io_gguf_has_metadata_array +#undef mlx_io_gguf_has_metadata_string +#undef mlx_io_gguf_has_metadata_vector_string +#undef mlx_io_gguf_set_array +#undef mlx_io_gguf_set_metadata_array +#undef mlx_io_gguf_set_metadata_string +#undef mlx_io_gguf_set_metadata_vector_string #undef mlx_linalg_cholesky #undef mlx_linalg_cholesky_inv #undef mlx_linalg_cross @@ -464,6 +489,10 @@ #undef mlx_slice_dynamic #undef mlx_slice_update #undef mlx_slice_update_dynamic +#undef mlx_slice_update_add +#undef mlx_slice_update_max +#undef mlx_slice_update_min +#undef mlx_slice_update_prod #undef mlx_softmax_axes #undef mlx_softmax_axis #undef mlx_softmax @@ -752,11 +781,13 @@ extern int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t extern int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s); extern int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s); extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +extern mlx_distributed_group (*mlx_distributed_group_new_ptr)(void); +extern int (*mlx_distributed_group_free_ptr)(mlx_distributed_group group); +extern int (*mlx_distributed_init_ptr)(mlx_distributed_group* res, bool strict, const char* bk); extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group); -extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key); +extern int (*mlx_distributed_group_split_ptr)(mlx_distributed_group* res, mlx_distributed_group group, int color, int key); extern bool (*mlx_distributed_is_available_ptr)(const char* bk); -extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk); extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...); extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); @@ -800,26 +831,35 @@ extern int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx extern int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s); extern int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s); extern int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s); -extern int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); -extern int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_fftfreq_ptr)(mlx_array* res, int n, double d, const mlx_stream s); +extern int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); -extern int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); -extern int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); -extern int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); -extern int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_rfftfreq_ptr)(mlx_array* res, int n, double d, const mlx_stream s); +extern int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_node_namer_free_ptr)(mlx_node_namer namer); +extern int (*mlx_node_namer_set_name_ptr)(mlx_node_namer namer, const mlx_array arr, const char* name); +extern int (*mlx_node_namer_get_name_ptr)(const char** name, mlx_node_namer namer, const mlx_array arr); +extern int (*mlx_export_to_dot_ptr)(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs); +extern int (*mlx_print_graph_ptr)(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs); extern int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); extern int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s); +extern int (*mlx_load_gguf_ptr)(mlx_io_gguf* gguf, const char* file, const mlx_stream s); extern int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s); extern int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s); extern int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a); extern int (*mlx_save_ptr)(const char* file, const mlx_array a); +extern int (*mlx_save_gguf_ptr)(const char* file, mlx_io_gguf gguf); extern int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); extern int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); extern mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable); @@ -830,6 +870,20 @@ extern mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable); extern int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io); extern int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io); extern int (*mlx_io_writer_free_ptr)(mlx_io_writer io); +extern mlx_io_gguf (*mlx_io_gguf_new_ptr)(void); +extern int (*mlx_io_gguf_free_ptr)(mlx_io_gguf io); +extern int (*mlx_io_gguf_get_keys_ptr)(mlx_vector_string* keys, mlx_io_gguf io); +extern int (*mlx_io_gguf_get_array_ptr)(mlx_array* arr, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_get_metadata_array_ptr)(mlx_array* arr, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_get_metadata_string_ptr)(mlx_string* str, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_get_metadata_vector_string_ptr)(mlx_vector_string* vstr, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_has_metadata_array_ptr)(bool* flag, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_has_metadata_string_ptr)(bool* flag, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_has_metadata_vector_string_ptr)(bool* flag, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_set_array_ptr)(mlx_io_gguf io, const char* key, const mlx_array arr); +extern int (*mlx_io_gguf_set_metadata_array_ptr)(mlx_io_gguf io, const char* key, const mlx_array marr); +extern int (*mlx_io_gguf_set_metadata_string_ptr)(mlx_io_gguf io, const char* key, const char* mstr); +extern int (*mlx_io_gguf_set_metadata_vector_string_ptr)(mlx_io_gguf io, const char* key, const mlx_vector_string mvstr); extern int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); extern int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); extern int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); @@ -1064,6 +1118,10 @@ extern int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, extern int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s); extern int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); extern int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_slice_update_add_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); +extern int (*mlx_slice_update_max_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); +extern int (*mlx_slice_update_min_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); +extern int (*mlx_slice_update_prod_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); extern int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s); extern int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s); extern int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s); @@ -1494,16 +1552,20 @@ int mlx_distributed_send(mlx_array* res, const mlx_array x, int dst, const mlx_d int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +mlx_distributed_group mlx_distributed_group_new(void); + +int mlx_distributed_group_free(mlx_distributed_group group); + +int mlx_distributed_init(mlx_distributed_group* res, bool strict, const char* bk); + int mlx_distributed_group_rank(mlx_distributed_group group); int mlx_distributed_group_size(mlx_distributed_group group); -mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key); +int mlx_distributed_group_split(mlx_distributed_group* res, mlx_distributed_group group, int color, int key); bool mlx_distributed_is_available(const char* bk); -mlx_distributed_group mlx_distributed_init(bool strict, const char* bk); - void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); void _mlx_error(const char* file, const int line, const char* fmt, ...); @@ -1590,38 +1652,54 @@ int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool trad int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s); -int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s); + +int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); -int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); -int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); -int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s); + +int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, mlx_fft_norm norm, const mlx_stream s); + +int mlx_node_namer_free(mlx_node_namer namer); + +int mlx_node_namer_set_name(mlx_node_namer namer, const mlx_array arr, const char* name); + +int mlx_node_namer_get_name(const char** name, mlx_node_namer namer, const mlx_array arr); + +int mlx_export_to_dot(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs); + +int mlx_print_graph(FILE* os, const mlx_node_namer namer, const mlx_vector_array outputs); int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); int mlx_load(mlx_array* res, const char* file, const mlx_stream s); +int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s); + int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s); int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s); @@ -1630,6 +1708,8 @@ int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); int mlx_save(const char* file, const mlx_array a); +int mlx_save_gguf(const char* file, mlx_io_gguf gguf); + int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); @@ -1650,6 +1730,34 @@ int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); int mlx_io_writer_free(mlx_io_writer io); +mlx_io_gguf mlx_io_gguf_new(void); + +int mlx_io_gguf_free(mlx_io_gguf io); + +int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io); + +int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_get_metadata_array(mlx_array* arr, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_get_metadata_string(mlx_string* str, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_get_metadata_vector_string(mlx_vector_string* vstr, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_has_metadata_string(bool* flag, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_has_metadata_vector_string(bool* flag, mlx_io_gguf io, const char* key); + +int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr); + +int mlx_io_gguf_set_metadata_array(mlx_io_gguf io, const char* key, const mlx_array marr); + +int mlx_io_gguf_set_metadata_string(mlx_io_gguf io, const char* key, const char* mstr); + +int mlx_io_gguf_set_metadata_vector_string(mlx_io_gguf io, const char* key, const mlx_vector_string mvstr); + int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); @@ -2118,6 +2226,14 @@ int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s); +int mlx_slice_update_add(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + +int mlx_slice_update_max(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + +int mlx_slice_update_min(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + +int mlx_slice_update_prod(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s); int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s); diff --git a/x/internal/mlxthread/thread.go b/x/internal/mlxthread/thread.go new file mode 100644 index 000000000..5c0419318 --- /dev/null +++ b/x/internal/mlxthread/thread.go @@ -0,0 +1,183 @@ +package mlxthread + +import ( + "context" + "errors" + "runtime" + "sync/atomic" +) + +var ErrStopped = errors.New("mlx thread stopped") + +type Thread struct { + name string + + jobs chan job + done chan struct{} + stopping atomic.Bool +} + +type job struct { + fn func() error + result chan result + stop bool +} + +type result struct { + err error + panicValue any +} + +// Start creates a long-lived worker goroutine locked to one OS thread. +func Start(name string, init func() error) (*Thread, error) { + t := &Thread{ + name: name, + jobs: make(chan job), + done: make(chan struct{}), + } + + initResult := make(chan result, 1) + go t.loop(init, initResult) + + res := <-initResult + if res.panicValue != nil { + panic(res.panicValue) + } + if res.err != nil { + return nil, res.err + } + + return t, nil +} + +// Do runs fn on the locked OS thread. +// +// Context cancellation only applies while the work is queued. Once the worker +// accepts a job, the job runs until fn returns or reaches its own cancellation +// checks. +func (t *Thread) Do(ctx context.Context, fn func() error) error { + res, err := t.enqueue(ctx, fn, false, false) + if err != nil { + return err + } + if res.panicValue != nil { + panic(res.panicValue) + } + return res.err +} + +func Call[T any](ctx context.Context, t *Thread, fn func() (T, error)) (T, error) { + var value T + err := t.Do(ctx, func() error { + var err error + value, err = fn() + return err + }) + return value, err +} + +// Stop runs cleanup on the locked OS thread and then shuts the worker down. +func (t *Thread) Stop(ctx context.Context, cleanup func()) error { + ctx = contextOrBackground(ctx) + + if !t.stopping.CompareAndSwap(false, true) { + select { + case <-t.done: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + res, err := t.enqueue(ctx, func() error { + if cleanup != nil { + cleanup() + } + return nil + }, true, true) + if err != nil { + if !errors.Is(err, ErrStopped) { + t.stopping.Store(false) + } + return err + } + if res.panicValue != nil { + panic(res.panicValue) + } + if res.err != nil { + return res.err + } + + select { + case <-t.done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (t *Thread) loop(init func() error, initResult chan<- result) { + runtime.LockOSThread() + // Deliberately do not unlock. MLX thread-local state belongs to this worker + // until shutdown so it cannot leak back to arbitrary Go goroutines. + + res := run(init) + initResult <- res + if res.err != nil || res.panicValue != nil { + close(t.done) + return + } + + for { + j := <-t.jobs + res := run(j.fn) + j.result <- res + if j.stop { + close(t.done) + return + } + } +} + +func (t *Thread) enqueue(ctx context.Context, fn func() error, stop, allowStopping bool) (result, error) { + ctx = contextOrBackground(ctx) + if err := ctx.Err(); err != nil { + return result{}, err + } + + if !allowStopping && t.stopping.Load() { + return result{}, ErrStopped + } + + resultCh := make(chan result, 1) + j := job{fn: fn, result: resultCh, stop: stop} + + select { + case <-ctx.Done(): + return result{}, ctx.Err() + case <-t.done: + return result{}, ErrStopped + case t.jobs <- j: + } + + return <-resultCh, nil +} + +func run(fn func() error) (res result) { + defer func() { + if v := recover(); v != nil { + res.panicValue = v + } + }() + if fn != nil { + res.err = fn() + } + return res +} + +func contextOrBackground(ctx context.Context) context.Context { + if ctx != nil { + return ctx + } + return context.Background() +} diff --git a/x/internal/mlxthread/thread_affinity_test.go b/x/internal/mlxthread/thread_affinity_test.go new file mode 100644 index 000000000..a8c1a0a54 --- /dev/null +++ b/x/internal/mlxthread/thread_affinity_test.go @@ -0,0 +1,32 @@ +//go:build darwin || linux + +package mlxthread + +import ( + "context" + "fmt" + "testing" +) + +func TestDoUsesSameOSThread(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + var first uint64 + for range 32 { + if err := thread.Do(context.Background(), func() error { + id := currentThreadID() + if first == 0 { + first = id + } else if id != first { + return fmt.Errorf("job ran on OS thread %d, want %d", id, first) + } + return nil + }); err != nil { + t.Fatal(err) + } + } +} diff --git a/x/internal/mlxthread/thread_test.go b/x/internal/mlxthread/thread_test.go new file mode 100644 index 000000000..9d511351b --- /dev/null +++ b/x/internal/mlxthread/thread_test.go @@ -0,0 +1,351 @@ +package mlxthread + +import ( + "context" + "errors" + "reflect" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDoRunsInOrder(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + var got []int + for i := 0; i < 5; i++ { + i := i + if err := thread.Do(context.Background(), func() error { + got = append(got, i) + return nil + }); err != nil { + t.Fatal(err) + } + } + + if want := []int{0, 1, 2, 3, 4}; !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } +} + +func TestDoPropagatesPanicToCaller(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + defer func() { + if got := recover(); got != "boom" { + t.Fatalf("got panic %v, want boom", got) + } + }() + + _ = thread.Do(context.Background(), func() error { + panic("boom") + }) +} + +func TestDoCancelsBeforeJobStarts(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + running := make(chan struct{}) + release := make(chan struct{}) + errCh := make(chan error, 1) + go func() { + errCh <- thread.Do(context.Background(), func() error { + close(running) + <-release + return nil + }) + }() + <-running + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = thread.Do(ctx, func() error { + t.Fatal("canceled job should not run") + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("got %v, want %v", err, context.Canceled) + } + + close(release) + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +func TestAlreadyCanceledContextDoesNotEnqueue(t *testing.T) { + t.Run("Do", func(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ran := false + err = thread.Do(ctx, func() error { + ran = true + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("got %v, want %v", err, context.Canceled) + } + if ran { + t.Fatal("canceled job ran") + } + }) + + t.Run("Stop", func(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + cleaned := false + err = thread.Stop(ctx, func() { + cleaned = true + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("got %v, want %v", err, context.Canceled) + } + if cleaned { + t.Fatal("cleanup ran for canceled stop") + } + if err := thread.Do(context.Background(), func() error { return nil }); err != nil { + t.Fatalf("thread did not accept work after canceled Stop: %v", err) + } + }) +} + +func TestCallReturnsValue(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + got, err := Call(context.Background(), thread, func() (int, error) { + return 42, nil + }) + if err != nil { + t.Fatal(err) + } + if got != 42 { + t.Fatalf("got %d, want 42", got) + } +} + +func TestDoRunsConcurrentlySubmittedWorkSerially(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + oldProcs := runtime.GOMAXPROCS(8) + defer runtime.GOMAXPROCS(oldProcs) + + const goroutines = 16 + const iterations = 64 + + var active atomic.Int32 + var count atomic.Int64 + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + + for range goroutines { + wg.Add(1) + go func() { + defer wg.Done() + + for range iterations { + if err := thread.Do(context.Background(), func() error { + if got := active.Add(1); got != 1 { + return errors.New("thread executed jobs concurrently") + } + runtime.Gosched() + count.Add(1) + if got := active.Add(-1); got != 0 { + return errors.New("thread active count did not return to zero") + } + return nil + }); err != nil { + errCh <- err + return + } + } + }() + } + + wg.Wait() + close(errCh) + + for err := range errCh { + t.Fatal(err) + } + if got, want := count.Load(), int64(goroutines*iterations); got != want { + t.Fatalf("got %d jobs, want %d", got, want) + } +} + +func TestStopRunsCleanupAndRejectsWork(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + + cleaned := 0 + if err := thread.Stop(context.Background(), func() { + cleaned++ + }); err != nil { + t.Fatal(err) + } + if cleaned != 1 { + t.Fatalf("cleanup ran %d times, want 1", cleaned) + } + + if err := thread.Stop(context.Background(), func() { + cleaned++ + }); err != nil { + t.Fatal(err) + } + if cleaned != 1 { + t.Fatalf("cleanup ran %d times after second Stop, want 1", cleaned) + } + + err = thread.Do(context.Background(), func() error { + t.Fatal("job should not run after stop") + return nil + }) + if !errors.Is(err, ErrStopped) { + t.Fatalf("got %v, want %v", err, ErrStopped) + } +} + +func TestStopCanceledBeforeEnqueueCanBeRetried(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + defer thread.Stop(context.Background(), nil) + + running := make(chan struct{}) + release := make(chan struct{}) + errCh := make(chan error, 1) + go func() { + errCh <- thread.Do(context.Background(), func() error { + close(running) + <-release + return nil + }) + }() + <-running + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + cleanupRan := false + err = thread.Stop(ctx, func() { + cleanupRan = true + }) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("got %v, want %v", err, context.DeadlineExceeded) + } + if cleanupRan { + t.Fatal("cleanup ran even though stop was not enqueued") + } + + close(release) + if err := <-errCh; err != nil { + t.Fatal(err) + } + + if err := thread.Do(context.Background(), func() error { return nil }); err != nil { + t.Fatalf("thread did not accept work after canceled Stop: %v", err) + } + + cleanupRan = false + if err := thread.Stop(context.Background(), func() { + cleanupRan = true + }); err != nil { + t.Fatal(err) + } + if !cleanupRan { + t.Fatal("cleanup did not run on retried Stop") + } +} + +func TestStopWaitsForActiveWorkBeforeCleanup(t *testing.T) { + thread, err := Start("test", nil) + if err != nil { + t.Fatal(err) + } + + running := make(chan struct{}) + release := make(chan struct{}) + jobErr := make(chan error, 1) + go func() { + jobErr <- thread.Do(context.Background(), func() error { + close(running) + <-release + return nil + }) + }() + <-running + + cleaned := make(chan struct{}) + stopErr := make(chan error, 1) + go func() { + stopErr <- thread.Stop(context.Background(), func() { + close(cleaned) + }) + }() + + select { + case <-cleaned: + t.Fatal("cleanup ran before active job completed") + case <-time.After(10 * time.Millisecond): + } + + err = thread.Do(context.Background(), func() error { + return errors.New("work should be rejected once Stop starts") + }) + if !errors.Is(err, ErrStopped) { + t.Fatalf("got %v, want %v", err, ErrStopped) + } + + close(release) + if err := <-jobErr; err != nil { + t.Fatal(err) + } + if err := <-stopErr; err != nil { + t.Fatal(err) + } + + select { + case <-cleaned: + default: + t.Fatal("cleanup did not run") + } +} diff --git a/x/internal/mlxthread/threadid_darwin_test.go b/x/internal/mlxthread/threadid_darwin_test.go new file mode 100644 index 000000000..c7240d138 --- /dev/null +++ b/x/internal/mlxthread/threadid_darwin_test.go @@ -0,0 +1,10 @@ +//go:build darwin + +package mlxthread + +import "syscall" + +func currentThreadID() uint64 { + id, _, _ := syscall.RawSyscall(syscall.SYS_THREAD_SELFID, 0, 0, 0) + return uint64(id) +} diff --git a/x/internal/mlxthread/threadid_linux_test.go b/x/internal/mlxthread/threadid_linux_test.go new file mode 100644 index 000000000..4eeb53502 --- /dev/null +++ b/x/internal/mlxthread/threadid_linux_test.go @@ -0,0 +1,9 @@ +//go:build linux + +package mlxthread + +import "syscall" + +func currentThreadID() uint64 { + return uint64(syscall.Gettid()) +} diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index e2774dced..dc2ce44df 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -2,7 +2,6 @@ package mlxrunner import ( "bufio" - "bytes" "context" "encoding/json" "errors" @@ -40,66 +39,11 @@ type Client struct { done chan struct{} doneErr error // valid after done is closed client *http.Client - status *statusWriter + status *llm.StatusWriter mu sync.Mutex cmd *exec.Cmd } -// statusWriter captures the last subprocess line while forwarding all output -// to os.Stderr. Lines longer than maxStatusLen are truncated to the first -// maxStatusLen bytes. -type statusWriter struct { - lastErrMsg string - buf []byte - discarding bool - mu sync.Mutex - out *os.File -} - -const maxStatusLen = 256 - -func (w *statusWriter) Write(b []byte) (int, error) { - n, err := w.out.Write(b) - - w.mu.Lock() - defer w.mu.Unlock() - - w.buf = append(w.buf, b...) - for { - i := bytes.IndexByte(w.buf, '\n') - if i < 0 { - break - } - if !w.discarding { - line := bytes.TrimSpace(w.buf[:i]) - if len(line) > 0 { - if len(line) > maxStatusLen { - line = line[:maxStatusLen] - } - w.lastErrMsg = string(line) - } - } - w.buf = w.buf[i+1:] - w.discarding = false - } - // if the buffer grows past maxStatusLen without a newline, keep the front - if len(w.buf) > maxStatusLen { - if !w.discarding { - w.lastErrMsg = string(bytes.TrimSpace(w.buf[:maxStatusLen])) - w.discarding = true - } - w.buf = w.buf[:0] - } - - return n, err -} - -func (w *statusWriter) getLastErr() string { - w.mu.Lock() - defer w.mu.Unlock() - return w.lastErrMsg -} - // NewClient prepares a new MLX runner client for LLM models. // The subprocess is not started until Load() is called. func NewClient(modelName string) (*Client, error) { @@ -133,12 +77,12 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case <-c.done: - if msg := c.status.getLastErr(); msg != "" { + if msg := c.status.LastError(); msg != "" { return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr) } return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr) case <-timeout: - if msg := c.status.getLastErr(); msg != "" { + if msg := c.status.LastError(); msg != "" { return fmt.Errorf("timeout waiting for mlx runner: %s", msg) } return errors.New("timeout waiting for mlx runner to start") @@ -217,7 +161,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f resp, err := c.client.Do(httpReq) if err != nil { - if errMsg := c.status.getLastErr(); errMsg != "" { + if errMsg := c.status.LastError(); errMsg != "" { return fmt.Errorf("mlx runner failed: %s", errMsg) } return err @@ -259,7 +203,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f } if err := scanner.Err(); err != nil { - if errMsg := c.status.getLastErr(); errMsg != "" { + if errMsg := c.status.LastError(); errMsg != "" { return fmt.Errorf("mlx runner failed: %s", errMsg) } return err @@ -405,7 +349,7 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo c.cmd = cmd - status := &statusWriter{out: os.Stderr} + status := llm.NewStatusWriter(os.Stderr) c.status = status // os/exec serializes Write calls when shared, which keeps the status writer // from seeing concurrent stdout/stderr fragments. diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go index bc6a4ca4a..375e674d9 100644 --- a/x/mlxrunner/mlx/array_test.go +++ b/x/mlxrunner/mlx/array_test.go @@ -2,51 +2,42 @@ package mlx import "testing" -func skipIfNoMLX(t *testing.T) { - t.Helper() - if err := CheckInit(); err != nil { - t.Skipf("MLX not available: %v", err) - } -} - func TestFromValue(t *testing.T) { - skipIfNoMLX(t) - for got, want := range map[*Array]DType{ - FromValue(true): DTypeBool, - FromValue(false): DTypeBool, - FromValue(int(7)): DTypeInt32, - FromValue(float32(3.14)): DTypeFloat32, - FromValue(float64(2.71)): DTypeFloat64, - FromValue(complex64(1 + 2i)): DTypeComplex64, - } { - t.Run(want.String(), func(t *testing.T) { + withMLXThread(t, func() { + for got, want := range map[*Array]DType{ + FromValue(true): DTypeBool, + FromValue(false): DTypeBool, + FromValue(int(7)): DTypeInt32, + FromValue(float32(3.14)): DTypeFloat32, + FromValue(float64(2.71)): DTypeFloat64, + FromValue(complex64(1 + 2i)): DTypeComplex64, + } { if got.DType() != want { - t.Errorf("want %v, got %v", want, got) + t.Errorf("%s: want %v, got %v", want, want, got) } - }) - } + } + }) } func TestFromValues(t *testing.T) { - skipIfNoMLX(t) - for got, want := range map[*Array]DType{ - FromValues([]bool{true, false, true}, 3): DTypeBool, - FromValues([]uint8{1, 2, 3}, 3): DTypeUint8, - FromValues([]uint16{1, 2, 3}, 3): DTypeUint16, - FromValues([]uint32{1, 2, 3}, 3): DTypeUint32, - FromValues([]uint64{1, 2, 3}, 3): DTypeUint64, - FromValues([]int8{-1, -2, -3}, 3): DTypeInt8, - FromValues([]int16{-1, -2, -3}, 3): DTypeInt16, - FromValues([]int32{-1, -2, -3}, 3): DTypeInt32, - FromValues([]int64{-1, -2, -3}, 3): DTypeInt64, - FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32, - FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64, - FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64, - } { - t.Run(want.String(), func(t *testing.T) { + withMLXThread(t, func() { + for got, want := range map[*Array]DType{ + FromValues([]bool{true, false, true}, 3): DTypeBool, + FromValues([]uint8{1, 2, 3}, 3): DTypeUint8, + FromValues([]uint16{1, 2, 3}, 3): DTypeUint16, + FromValues([]uint32{1, 2, 3}, 3): DTypeUint32, + FromValues([]uint64{1, 2, 3}, 3): DTypeUint64, + FromValues([]int8{-1, -2, -3}, 3): DTypeInt8, + FromValues([]int16{-1, -2, -3}, 3): DTypeInt16, + FromValues([]int32{-1, -2, -3}, 3): DTypeInt32, + FromValues([]int64{-1, -2, -3}, 3): DTypeInt64, + FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32, + FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64, + FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64, + } { if got.DType() != want { - t.Errorf("want %v, got %v", want, got) + t.Errorf("%s: want %v, got %v", want, want, got) } - }) - } + } + }) } diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c index f1333971c..8147fe5b4 100644 --- a/x/mlxrunner/mlx/generated.c +++ b/x/mlxrunner/mlx/generated.c @@ -323,13 +323,20 @@ int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) = NULL; -int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; -int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; -mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; -bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL; -mlx_distributed_group (*mlx_distributed_init_)( +mlx_distributed_group (*mlx_distributed_group_new_)(void) = NULL; +int (*mlx_distributed_group_free_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_init_)( + mlx_distributed_group* res, bool strict, const char* bk /* may be null */) = NULL; +int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_split_)( + mlx_distributed_group* res, + mlx_distributed_group group, + int color, + int key) = NULL; +bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL; void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -517,6 +524,7 @@ int (*mlx_fft_fft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_fft2_)( mlx_array* res, @@ -525,7 +533,9 @@ int (*mlx_fft_fft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_fftfreq_)(mlx_array* res, int n, double d, const mlx_stream s) = NULL; int (*mlx_fft_fftn_)( mlx_array* res, const mlx_array a, @@ -533,6 +543,7 @@ int (*mlx_fft_fftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_fftshift_)( mlx_array* res, @@ -545,6 +556,7 @@ int (*mlx_fft_ifft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_ifft2_)( mlx_array* res, @@ -553,6 +565,7 @@ int (*mlx_fft_ifft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_ifftn_)( mlx_array* res, @@ -561,6 +574,7 @@ int (*mlx_fft_ifftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_ifftshift_)( mlx_array* res, @@ -573,6 +587,7 @@ int (*mlx_fft_irfft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_irfft2_)( mlx_array* res, @@ -581,6 +596,7 @@ int (*mlx_fft_irfft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_irfftn_)( mlx_array* res, @@ -589,12 +605,14 @@ int (*mlx_fft_irfftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_rfft_)( mlx_array* res, const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) = NULL; int (*mlx_fft_rfft2_)( mlx_array* res, @@ -603,7 +621,9 @@ int (*mlx_fft_rfft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; +int (*mlx_fft_rfftfreq_)(mlx_array* res, int n, double d, const mlx_stream s) = NULL; int (*mlx_fft_rfftn_)( mlx_array* res, const mlx_array a, @@ -611,12 +631,32 @@ int (*mlx_fft_rfftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) = NULL; +mlx_node_namer (*mlx_node_namer_new_)() = NULL; +int (*mlx_node_namer_free_)(mlx_node_namer namer) = NULL; +int (*mlx_node_namer_set_name_)( + mlx_node_namer namer, + const mlx_array arr, + const char* name) = NULL; +int (*mlx_node_namer_get_name_)( + const char** name, + mlx_node_namer namer, + const mlx_array arr) = NULL; +int (*mlx_export_to_dot_)( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs) = NULL; +int (*mlx_print_graph_)( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs) = NULL; int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL; int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s) = NULL; +int (*mlx_load_gguf_)(mlx_io_gguf* gguf, const char* file, const mlx_stream s) = NULL; int (*mlx_load_safetensors_reader_)( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -629,6 +669,7 @@ int (*mlx_load_safetensors_)( const mlx_stream s) = NULL; int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a) = NULL; int (*mlx_save_)(const char* file, const mlx_array a) = NULL; +int (*mlx_save_gguf_)(const char* file, mlx_io_gguf gguf) = NULL; int (*mlx_save_safetensors_writer_)( mlx_io_writer in_stream, const mlx_map_string_to_array param, @@ -645,6 +686,44 @@ mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; +mlx_io_gguf (*mlx_io_gguf_new_)(void) = NULL; +int (*mlx_io_gguf_free_)(mlx_io_gguf io) = NULL; +int (*mlx_io_gguf_get_keys_)(mlx_vector_string* keys, mlx_io_gguf io) = NULL; +int (*mlx_io_gguf_get_array_)(mlx_array* arr, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_array_)( + mlx_array* arr, + mlx_io_gguf io, + const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_string_)( + mlx_string* str, + mlx_io_gguf io, + const char* key) = NULL; +int (*mlx_io_gguf_get_metadata_vector_string_)( + mlx_vector_string* vstr, + mlx_io_gguf io, + const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_array_)(bool* flag, mlx_io_gguf io, const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_string_)( + bool* flag, + mlx_io_gguf io, + const char* key) = NULL; +int (*mlx_io_gguf_has_metadata_vector_string_)( + bool* flag, + mlx_io_gguf io, + const char* key) = NULL; +int (*mlx_io_gguf_set_array_)(mlx_io_gguf io, const char* key, const mlx_array arr) = NULL; +int (*mlx_io_gguf_set_metadata_array_)( + mlx_io_gguf io, + const char* key, + const mlx_array marr) = NULL; +int (*mlx_io_gguf_set_metadata_string_)( + mlx_io_gguf io, + const char* key, + const char* mstr) = NULL; +int (*mlx_io_gguf_set_metadata_vector_string_)( + mlx_io_gguf io, + const char* key, + const mlx_vector_string mvstr) = NULL; int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -1764,6 +1843,50 @@ int (*mlx_slice_update_dynamic_)( const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_add_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; +int (*mlx_slice_update_max_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; +int (*mlx_slice_update_min_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; +int (*mlx_slice_update_prod_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; int (*mlx_softmax_axes_)( mlx_array* res, const mlx_array a, @@ -2403,11 +2526,13 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_distributed_recv_like); CHECK_LOAD(handle, mlx_distributed_send); CHECK_LOAD(handle, mlx_distributed_sum_scatter); + CHECK_LOAD(handle, mlx_distributed_group_new); + CHECK_LOAD(handle, mlx_distributed_group_free); + CHECK_LOAD(handle, mlx_distributed_init); CHECK_LOAD(handle, mlx_distributed_group_rank); CHECK_LOAD(handle, mlx_distributed_group_size); CHECK_LOAD(handle, mlx_distributed_group_split); CHECK_LOAD(handle, mlx_distributed_is_available); - CHECK_LOAD(handle, mlx_distributed_init); CHECK_LOAD(handle, mlx_set_error_handler); CHECK_LOAD(handle, _mlx_error); CHECK_LOAD(handle, mlx_export_function); @@ -2453,6 +2578,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention); CHECK_LOAD(handle, mlx_fft_fft); CHECK_LOAD(handle, mlx_fft_fft2); + CHECK_LOAD(handle, mlx_fft_fftfreq); CHECK_LOAD(handle, mlx_fft_fftn); CHECK_LOAD(handle, mlx_fft_fftshift); CHECK_LOAD(handle, mlx_fft_ifft); @@ -2464,13 +2590,22 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_fft_irfftn); CHECK_LOAD(handle, mlx_fft_rfft); CHECK_LOAD(handle, mlx_fft_rfft2); + CHECK_LOAD(handle, mlx_fft_rfftfreq); CHECK_LOAD(handle, mlx_fft_rfftn); + CHECK_LOAD(handle, mlx_node_namer_new); + CHECK_LOAD(handle, mlx_node_namer_free); + CHECK_LOAD(handle, mlx_node_namer_set_name); + CHECK_LOAD(handle, mlx_node_namer_get_name); + CHECK_LOAD(handle, mlx_export_to_dot); + CHECK_LOAD(handle, mlx_print_graph); CHECK_LOAD(handle, mlx_load_reader); CHECK_LOAD(handle, mlx_load); + CHECK_LOAD(handle, mlx_load_gguf); CHECK_LOAD(handle, mlx_load_safetensors_reader); CHECK_LOAD(handle, mlx_load_safetensors); CHECK_LOAD(handle, mlx_save_writer); CHECK_LOAD(handle, mlx_save); + CHECK_LOAD(handle, mlx_save_gguf); CHECK_LOAD(handle, mlx_save_safetensors_writer); CHECK_LOAD(handle, mlx_save_safetensors); CHECK_LOAD(handle, mlx_io_reader_new); @@ -2481,6 +2616,20 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_io_writer_descriptor); CHECK_LOAD(handle, mlx_io_writer_tostring); CHECK_LOAD(handle, mlx_io_writer_free); + CHECK_LOAD(handle, mlx_io_gguf_new); + CHECK_LOAD(handle, mlx_io_gguf_free); + CHECK_LOAD(handle, mlx_io_gguf_get_keys); + CHECK_LOAD(handle, mlx_io_gguf_get_array); + CHECK_LOAD(handle, mlx_io_gguf_get_metadata_array); + CHECK_LOAD(handle, mlx_io_gguf_get_metadata_string); + CHECK_LOAD(handle, mlx_io_gguf_get_metadata_vector_string); + CHECK_LOAD(handle, mlx_io_gguf_has_metadata_array); + CHECK_LOAD(handle, mlx_io_gguf_has_metadata_string); + CHECK_LOAD(handle, mlx_io_gguf_has_metadata_vector_string); + CHECK_LOAD(handle, mlx_io_gguf_set_array); + CHECK_LOAD(handle, mlx_io_gguf_set_metadata_array); + CHECK_LOAD(handle, mlx_io_gguf_set_metadata_string); + CHECK_LOAD(handle, mlx_io_gguf_set_metadata_vector_string); CHECK_LOAD(handle, mlx_linalg_cholesky); CHECK_LOAD(handle, mlx_linalg_cholesky_inv); CHECK_LOAD(handle, mlx_linalg_cross); @@ -2715,6 +2864,10 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { CHECK_LOAD(handle, mlx_slice_dynamic); CHECK_LOAD(handle, mlx_slice_update); CHECK_LOAD(handle, mlx_slice_update_dynamic); + CHECK_LOAD(handle, mlx_slice_update_add); + CHECK_LOAD(handle, mlx_slice_update_max); + CHECK_LOAD(handle, mlx_slice_update_min); + CHECK_LOAD(handle, mlx_slice_update_prod); CHECK_LOAD(handle, mlx_softmax_axes); CHECK_LOAD(handle, mlx_softmax_axis); CHECK_LOAD(handle, mlx_softmax); diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h index 26119f2ff..02beed9d0 100644 --- a/x/mlxrunner/mlx/generated.h +++ b/x/mlxrunner/mlx/generated.h @@ -143,11 +143,13 @@ #define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_ #define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_ #define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_ +#define mlx_distributed_group_new mlx_distributed_group_new_mlx_gen_orig_ +#define mlx_distributed_group_free mlx_distributed_group_free_mlx_gen_orig_ +#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ #define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ #define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ #define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ #define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ -#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ #define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_ #define _mlx_error _mlx_error_mlx_gen_orig_ #define mlx_export_function mlx_export_function_mlx_gen_orig_ @@ -193,6 +195,7 @@ #define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_ #define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_ #define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_ +#define mlx_fft_fftfreq mlx_fft_fftfreq_mlx_gen_orig_ #define mlx_fft_fftn mlx_fft_fftn_mlx_gen_orig_ #define mlx_fft_fftshift mlx_fft_fftshift_mlx_gen_orig_ #define mlx_fft_ifft mlx_fft_ifft_mlx_gen_orig_ @@ -204,13 +207,22 @@ #define mlx_fft_irfftn mlx_fft_irfftn_mlx_gen_orig_ #define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_ #define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_ +#define mlx_fft_rfftfreq mlx_fft_rfftfreq_mlx_gen_orig_ #define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_ +#define mlx_node_namer_new mlx_node_namer_new_mlx_gen_orig_ +#define mlx_node_namer_free mlx_node_namer_free_mlx_gen_orig_ +#define mlx_node_namer_set_name mlx_node_namer_set_name_mlx_gen_orig_ +#define mlx_node_namer_get_name mlx_node_namer_get_name_mlx_gen_orig_ +#define mlx_export_to_dot mlx_export_to_dot_mlx_gen_orig_ +#define mlx_print_graph mlx_print_graph_mlx_gen_orig_ #define mlx_load_reader mlx_load_reader_mlx_gen_orig_ #define mlx_load mlx_load_mlx_gen_orig_ +#define mlx_load_gguf mlx_load_gguf_mlx_gen_orig_ #define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_ #define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_ #define mlx_save_writer mlx_save_writer_mlx_gen_orig_ #define mlx_save mlx_save_mlx_gen_orig_ +#define mlx_save_gguf mlx_save_gguf_mlx_gen_orig_ #define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_ #define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_ #define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_ @@ -221,6 +233,20 @@ #define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_ #define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_ #define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_ +#define mlx_io_gguf_new mlx_io_gguf_new_mlx_gen_orig_ +#define mlx_io_gguf_free mlx_io_gguf_free_mlx_gen_orig_ +#define mlx_io_gguf_get_keys mlx_io_gguf_get_keys_mlx_gen_orig_ +#define mlx_io_gguf_get_array mlx_io_gguf_get_array_mlx_gen_orig_ +#define mlx_io_gguf_get_metadata_array mlx_io_gguf_get_metadata_array_mlx_gen_orig_ +#define mlx_io_gguf_get_metadata_string mlx_io_gguf_get_metadata_string_mlx_gen_orig_ +#define mlx_io_gguf_get_metadata_vector_string mlx_io_gguf_get_metadata_vector_string_mlx_gen_orig_ +#define mlx_io_gguf_has_metadata_array mlx_io_gguf_has_metadata_array_mlx_gen_orig_ +#define mlx_io_gguf_has_metadata_string mlx_io_gguf_has_metadata_string_mlx_gen_orig_ +#define mlx_io_gguf_has_metadata_vector_string mlx_io_gguf_has_metadata_vector_string_mlx_gen_orig_ +#define mlx_io_gguf_set_array mlx_io_gguf_set_array_mlx_gen_orig_ +#define mlx_io_gguf_set_metadata_array mlx_io_gguf_set_metadata_array_mlx_gen_orig_ +#define mlx_io_gguf_set_metadata_string mlx_io_gguf_set_metadata_string_mlx_gen_orig_ +#define mlx_io_gguf_set_metadata_vector_string mlx_io_gguf_set_metadata_vector_string_mlx_gen_orig_ #define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_ #define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_ #define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_ @@ -455,6 +481,10 @@ #define mlx_slice_dynamic mlx_slice_dynamic_mlx_gen_orig_ #define mlx_slice_update mlx_slice_update_mlx_gen_orig_ #define mlx_slice_update_dynamic mlx_slice_update_dynamic_mlx_gen_orig_ +#define mlx_slice_update_add mlx_slice_update_add_mlx_gen_orig_ +#define mlx_slice_update_max mlx_slice_update_max_mlx_gen_orig_ +#define mlx_slice_update_min mlx_slice_update_min_mlx_gen_orig_ +#define mlx_slice_update_prod mlx_slice_update_prod_mlx_gen_orig_ #define mlx_softmax_axes mlx_softmax_axes_mlx_gen_orig_ #define mlx_softmax_axis mlx_softmax_axis_mlx_gen_orig_ #define mlx_softmax mlx_softmax_mlx_gen_orig_ @@ -736,11 +766,13 @@ #undef mlx_distributed_recv_like #undef mlx_distributed_send #undef mlx_distributed_sum_scatter +#undef mlx_distributed_group_new +#undef mlx_distributed_group_free +#undef mlx_distributed_init #undef mlx_distributed_group_rank #undef mlx_distributed_group_size #undef mlx_distributed_group_split #undef mlx_distributed_is_available -#undef mlx_distributed_init #undef mlx_set_error_handler #undef _mlx_error #undef mlx_export_function @@ -786,6 +818,7 @@ #undef mlx_fast_scaled_dot_product_attention #undef mlx_fft_fft #undef mlx_fft_fft2 +#undef mlx_fft_fftfreq #undef mlx_fft_fftn #undef mlx_fft_fftshift #undef mlx_fft_ifft @@ -797,13 +830,22 @@ #undef mlx_fft_irfftn #undef mlx_fft_rfft #undef mlx_fft_rfft2 +#undef mlx_fft_rfftfreq #undef mlx_fft_rfftn +#undef mlx_node_namer_new +#undef mlx_node_namer_free +#undef mlx_node_namer_set_name +#undef mlx_node_namer_get_name +#undef mlx_export_to_dot +#undef mlx_print_graph #undef mlx_load_reader #undef mlx_load +#undef mlx_load_gguf #undef mlx_load_safetensors_reader #undef mlx_load_safetensors #undef mlx_save_writer #undef mlx_save +#undef mlx_save_gguf #undef mlx_save_safetensors_writer #undef mlx_save_safetensors #undef mlx_io_reader_new @@ -814,6 +856,20 @@ #undef mlx_io_writer_descriptor #undef mlx_io_writer_tostring #undef mlx_io_writer_free +#undef mlx_io_gguf_new +#undef mlx_io_gguf_free +#undef mlx_io_gguf_get_keys +#undef mlx_io_gguf_get_array +#undef mlx_io_gguf_get_metadata_array +#undef mlx_io_gguf_get_metadata_string +#undef mlx_io_gguf_get_metadata_vector_string +#undef mlx_io_gguf_has_metadata_array +#undef mlx_io_gguf_has_metadata_string +#undef mlx_io_gguf_has_metadata_vector_string +#undef mlx_io_gguf_set_array +#undef mlx_io_gguf_set_metadata_array +#undef mlx_io_gguf_set_metadata_string +#undef mlx_io_gguf_set_metadata_vector_string #undef mlx_linalg_cholesky #undef mlx_linalg_cholesky_inv #undef mlx_linalg_cross @@ -1048,6 +1104,10 @@ #undef mlx_slice_dynamic #undef mlx_slice_update #undef mlx_slice_update_dynamic +#undef mlx_slice_update_add +#undef mlx_slice_update_max +#undef mlx_slice_update_min +#undef mlx_slice_update_prod #undef mlx_softmax_axes #undef mlx_softmax_axis #undef mlx_softmax @@ -1506,13 +1566,20 @@ extern int (*mlx_distributed_sum_scatter_)( const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); -extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); -extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); -extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); -extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */); -extern mlx_distributed_group (*mlx_distributed_init_)( +extern mlx_distributed_group (*mlx_distributed_group_new_)(void); +extern int (*mlx_distributed_group_free_)(mlx_distributed_group group); +extern int (*mlx_distributed_init_)( + mlx_distributed_group* res, bool strict, const char* bk /* may be null */); +extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); +extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); +extern int (*mlx_distributed_group_split_)( + mlx_distributed_group* res, + mlx_distributed_group group, + int color, + int key); +extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */); extern void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, @@ -1700,6 +1767,7 @@ extern int (*mlx_fft_fft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_fft2_)( mlx_array* res, @@ -1708,7 +1776,9 @@ extern int (*mlx_fft_fft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_fftfreq_)(mlx_array* res, int n, double d, const mlx_stream s); extern int (*mlx_fft_fftn_)( mlx_array* res, const mlx_array a, @@ -1716,6 +1786,7 @@ extern int (*mlx_fft_fftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_fftshift_)( mlx_array* res, @@ -1728,6 +1799,7 @@ extern int (*mlx_fft_ifft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_ifft2_)( mlx_array* res, @@ -1736,6 +1808,7 @@ extern int (*mlx_fft_ifft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_ifftn_)( mlx_array* res, @@ -1744,6 +1817,7 @@ extern int (*mlx_fft_ifftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_ifftshift_)( mlx_array* res, @@ -1756,6 +1830,7 @@ extern int (*mlx_fft_irfft_)( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_irfft2_)( mlx_array* res, @@ -1764,6 +1839,7 @@ extern int (*mlx_fft_irfft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_irfftn_)( mlx_array* res, @@ -1772,12 +1848,14 @@ extern int (*mlx_fft_irfftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_rfft_)( mlx_array* res, const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); extern int (*mlx_fft_rfft2_)( mlx_array* res, @@ -1786,7 +1864,9 @@ extern int (*mlx_fft_rfft2_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); +extern int (*mlx_fft_rfftfreq_)(mlx_array* res, int n, double d, const mlx_stream s); extern int (*mlx_fft_rfftn_)( mlx_array* res, const mlx_array a, @@ -1794,12 +1874,32 @@ extern int (*mlx_fft_rfftn_)( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); +extern mlx_node_namer (*mlx_node_namer_new_)(); +extern int (*mlx_node_namer_free_)(mlx_node_namer namer); +extern int (*mlx_node_namer_set_name_)( + mlx_node_namer namer, + const mlx_array arr, + const char* name); +extern int (*mlx_node_namer_get_name_)( + const char** name, + mlx_node_namer namer, + const mlx_array arr); +extern int (*mlx_export_to_dot_)( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs); +extern int (*mlx_print_graph_)( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs); extern int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); extern int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s); +extern int (*mlx_load_gguf_)(mlx_io_gguf* gguf, const char* file, const mlx_stream s); extern int (*mlx_load_safetensors_reader_)( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -1812,6 +1912,7 @@ extern int (*mlx_load_safetensors_)( const mlx_stream s); extern int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a); extern int (*mlx_save_)(const char* file, const mlx_array a); +extern int (*mlx_save_gguf_)(const char* file, mlx_io_gguf gguf); extern int (*mlx_save_safetensors_writer_)( mlx_io_writer in_stream, const mlx_map_string_to_array param, @@ -1828,6 +1929,44 @@ extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); extern int (*mlx_io_writer_free_)(mlx_io_writer io); +extern mlx_io_gguf (*mlx_io_gguf_new_)(void); +extern int (*mlx_io_gguf_free_)(mlx_io_gguf io); +extern int (*mlx_io_gguf_get_keys_)(mlx_vector_string* keys, mlx_io_gguf io); +extern int (*mlx_io_gguf_get_array_)(mlx_array* arr, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_get_metadata_array_)( + mlx_array* arr, + mlx_io_gguf io, + const char* key); +extern int (*mlx_io_gguf_get_metadata_string_)( + mlx_string* str, + mlx_io_gguf io, + const char* key); +extern int (*mlx_io_gguf_get_metadata_vector_string_)( + mlx_vector_string* vstr, + mlx_io_gguf io, + const char* key); +extern int (*mlx_io_gguf_has_metadata_array_)(bool* flag, mlx_io_gguf io, const char* key); +extern int (*mlx_io_gguf_has_metadata_string_)( + bool* flag, + mlx_io_gguf io, + const char* key); +extern int (*mlx_io_gguf_has_metadata_vector_string_)( + bool* flag, + mlx_io_gguf io, + const char* key); +extern int (*mlx_io_gguf_set_array_)(mlx_io_gguf io, const char* key, const mlx_array arr); +extern int (*mlx_io_gguf_set_metadata_array_)( + mlx_io_gguf io, + const char* key, + const mlx_array marr); +extern int (*mlx_io_gguf_set_metadata_string_)( + mlx_io_gguf io, + const char* key, + const char* mstr); +extern int (*mlx_io_gguf_set_metadata_vector_string_)( + mlx_io_gguf io, + const char* key, + const mlx_vector_string mvstr); extern int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, @@ -2947,6 +3086,50 @@ extern int (*mlx_slice_update_dynamic_)( const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_slice_update_add_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +extern int (*mlx_slice_update_max_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +extern int (*mlx_slice_update_min_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +extern int (*mlx_slice_update_prod_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); extern int (*mlx_softmax_axes_)( mlx_array* res, const mlx_array a, @@ -4042,23 +4225,34 @@ static inline int mlx_distributed_sum_scatter( const mlx_stream s) { return mlx_distributed_sum_scatter_(res, x, group, s); } +static inline mlx_distributed_group mlx_distributed_group_new(void) { + return mlx_distributed_group_new_(); +} +static inline int mlx_distributed_group_free(mlx_distributed_group group) { + return mlx_distributed_group_free_(group); +} +static inline int mlx_distributed_init( + mlx_distributed_group* res, + bool strict, + const char* bk /* may be null */) { + return mlx_distributed_init_(res, strict, bk); +} static inline int mlx_distributed_group_rank(mlx_distributed_group group) { return mlx_distributed_group_rank_(group); } static inline int mlx_distributed_group_size(mlx_distributed_group group) { return mlx_distributed_group_size_(group); } -static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { - return mlx_distributed_group_split_(group, color, key); +static inline int mlx_distributed_group_split( + mlx_distributed_group* res, + mlx_distributed_group group, + int color, + int key) { + return mlx_distributed_group_split_(res, group, color, key); } static inline bool mlx_distributed_is_available(const char* bk /* may be null */) { return mlx_distributed_is_available_(bk); } -static inline mlx_distributed_group mlx_distributed_init( - bool strict, - const char* bk /* may be null */) { - return mlx_distributed_init_(strict, bk); -} static inline void mlx_set_error_handler( mlx_error_handler_func handler, void* data, @@ -4332,8 +4526,9 @@ static inline int mlx_fft_fft( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_fft_(res, a, n, axis, s); + return mlx_fft_fft_(res, a, n, axis, norm, s); } static inline int mlx_fft_fft2( mlx_array* res, @@ -4342,8 +4537,12 @@ static inline int mlx_fft_fft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, norm, s); +} +static inline int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s) { + return mlx_fft_fftfreq_(res, n, d, s); } static inline int mlx_fft_fftn( mlx_array* res, @@ -4352,8 +4551,9 @@ static inline int mlx_fft_fftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, norm, s); } static inline int mlx_fft_fftshift( mlx_array* res, @@ -4368,8 +4568,9 @@ static inline int mlx_fft_ifft( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_ifft_(res, a, n, axis, s); + return mlx_fft_ifft_(res, a, n, axis, norm, s); } static inline int mlx_fft_ifft2( mlx_array* res, @@ -4378,8 +4579,9 @@ static inline int mlx_fft_ifft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, norm, s); } static inline int mlx_fft_ifftn( mlx_array* res, @@ -4388,8 +4590,9 @@ static inline int mlx_fft_ifftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, norm, s); } static inline int mlx_fft_ifftshift( mlx_array* res, @@ -4404,8 +4607,9 @@ static inline int mlx_fft_irfft( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_irfft_(res, a, n, axis, s); + return mlx_fft_irfft_(res, a, n, axis, norm, s); } static inline int mlx_fft_irfft2( mlx_array* res, @@ -4414,8 +4618,9 @@ static inline int mlx_fft_irfft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, norm, s); } static inline int mlx_fft_irfftn( mlx_array* res, @@ -4424,16 +4629,18 @@ static inline int mlx_fft_irfftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, norm, s); } static inline int mlx_fft_rfft( mlx_array* res, const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_rfft_(res, a, n, axis, s); + return mlx_fft_rfft_(res, a, n, axis, norm, s); } static inline int mlx_fft_rfft2( mlx_array* res, @@ -4442,8 +4649,12 @@ static inline int mlx_fft_rfft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, norm, s); +} +static inline int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s) { + return mlx_fft_rfftfreq_(res, n, d, s); } static inline int mlx_fft_rfftn( mlx_array* res, @@ -4452,8 +4663,39 @@ static inline int mlx_fft_rfftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s) { - return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); + return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, norm, s); +} +static inline mlx_node_namer mlx_node_namer_new() { + return mlx_node_namer_new_(); +} +static inline int mlx_node_namer_free(mlx_node_namer namer) { + return mlx_node_namer_free_(namer); +} +static inline int mlx_node_namer_set_name( + mlx_node_namer namer, + const mlx_array arr, + const char* name) { + return mlx_node_namer_set_name_(namer, arr, name); +} +static inline int mlx_node_namer_get_name( + const char** name, + mlx_node_namer namer, + const mlx_array arr) { + return mlx_node_namer_get_name_(name, namer, arr); +} +static inline int mlx_export_to_dot( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs) { + return mlx_export_to_dot_(os, namer, outputs); +} +static inline int mlx_print_graph( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs) { + return mlx_print_graph_(os, namer, outputs); } static inline int mlx_load_reader( mlx_array* res, @@ -4464,6 +4706,9 @@ static inline int mlx_load_reader( static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_(res, file, s); } +static inline int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s) { + return mlx_load_gguf_(gguf, file, s); +} static inline int mlx_load_safetensors_reader( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -4484,6 +4729,9 @@ static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { static inline int mlx_save(const char* file, const mlx_array a) { return mlx_save_(file, a); } +static inline int mlx_save_gguf(const char* file, mlx_io_gguf gguf) { + return mlx_save_gguf_(file, gguf); +} static inline int mlx_save_safetensors_writer( mlx_io_writer in_stream, const mlx_map_string_to_array param, @@ -4520,6 +4768,72 @@ static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { static inline int mlx_io_writer_free(mlx_io_writer io) { return mlx_io_writer_free_(io); } +static inline mlx_io_gguf mlx_io_gguf_new(void) { + return mlx_io_gguf_new_(); +} +static inline int mlx_io_gguf_free(mlx_io_gguf io) { + return mlx_io_gguf_free_(io); +} +static inline int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io) { + return mlx_io_gguf_get_keys_(keys, io); +} +static inline int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_get_array_(arr, io, key); +} +static inline int mlx_io_gguf_get_metadata_array( + mlx_array* arr, + mlx_io_gguf io, + const char* key) { + return mlx_io_gguf_get_metadata_array_(arr, io, key); +} +static inline int mlx_io_gguf_get_metadata_string( + mlx_string* str, + mlx_io_gguf io, + const char* key) { + return mlx_io_gguf_get_metadata_string_(str, io, key); +} +static inline int mlx_io_gguf_get_metadata_vector_string( + mlx_vector_string* vstr, + mlx_io_gguf io, + const char* key) { + return mlx_io_gguf_get_metadata_vector_string_(vstr, io, key); +} +static inline int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key) { + return mlx_io_gguf_has_metadata_array_(flag, io, key); +} +static inline int mlx_io_gguf_has_metadata_string( + bool* flag, + mlx_io_gguf io, + const char* key) { + return mlx_io_gguf_has_metadata_string_(flag, io, key); +} +static inline int mlx_io_gguf_has_metadata_vector_string( + bool* flag, + mlx_io_gguf io, + const char* key) { + return mlx_io_gguf_has_metadata_vector_string_(flag, io, key); +} +static inline int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr) { + return mlx_io_gguf_set_array_(io, key, arr); +} +static inline int mlx_io_gguf_set_metadata_array( + mlx_io_gguf io, + const char* key, + const mlx_array marr) { + return mlx_io_gguf_set_metadata_array_(io, key, marr); +} +static inline int mlx_io_gguf_set_metadata_string( + mlx_io_gguf io, + const char* key, + const char* mstr) { + return mlx_io_gguf_set_metadata_string_(io, key, mstr); +} +static inline int mlx_io_gguf_set_metadata_vector_string( + mlx_io_gguf io, + const char* key, + const mlx_vector_string mvstr) { + return mlx_io_gguf_set_metadata_vector_string_(io, key, mvstr); +} static inline int mlx_linalg_cholesky( mlx_array* res, const mlx_array a, @@ -6107,6 +6421,58 @@ static inline int mlx_slice_update_dynamic( const mlx_stream s) { return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s); } +static inline int mlx_slice_update_add( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_update_add_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} +static inline int mlx_slice_update_max( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_update_max_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} +static inline int mlx_slice_update_min( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_update_min_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} +static inline int mlx_slice_update_prod( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_update_prod_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} static inline int mlx_softmax_axes( mlx_array* res, const mlx_array a, diff --git a/x/mlxrunner/mlx/include/mlx/c/compile.h b/x/mlxrunner/mlx/include/mlx/c/compile.h index 04567fb3a..2892a1f0b 100644 --- a/x/mlxrunner/mlx/include/mlx/c/compile.h +++ b/x/mlxrunner/mlx/include/mlx/c/compile.h @@ -34,6 +34,7 @@ typedef enum mlx_compile_mode_ { MLX_COMPILE_MODE_NO_FUSE, MLX_COMPILE_MODE_ENABLED } mlx_compile_mode; + int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless); int mlx_detail_compile( mlx_closure* res, diff --git a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h index 43aa2ae56..bfbaa80eb 100644 --- a/x/mlxrunner/mlx/include/mlx/c/distributed_group.h +++ b/x/mlxrunner/mlx/include/mlx/c/distributed_group.h @@ -23,6 +23,24 @@ typedef struct mlx_distributed_group_ { void* ctx; } mlx_distributed_group; +/** + * Create an empty group. + */ +mlx_distributed_group mlx_distributed_group_new(void); + +/** + * Free the group. + */ +int mlx_distributed_group_free(mlx_distributed_group group); + +/** + * Initialize distributed. + */ +int mlx_distributed_init( + mlx_distributed_group* res, + bool strict, + const char* bk /* may be null */); + /** * Get the rank. */ @@ -36,21 +54,17 @@ int mlx_distributed_group_size(mlx_distributed_group group); /** * Split the group. */ -mlx_distributed_group -mlx_distributed_group_split(mlx_distributed_group group, int color, int key); +int mlx_distributed_group_split( + mlx_distributed_group* res, + mlx_distributed_group group, + int color, + int key); /** * Check if distributed is available. */ bool mlx_distributed_is_available(const char* bk /* may be null */); -/** - * Initialize distributed. - */ -mlx_distributed_group mlx_distributed_init( - bool strict, - const char* bk /* may be null */); - /**@}*/ #ifdef __cplusplus diff --git a/x/mlxrunner/mlx/include/mlx/c/fft.h b/x/mlxrunner/mlx/include/mlx/c/fft.h index 779803e9b..7140b601f 100644 --- a/x/mlxrunner/mlx/include/mlx/c/fft.h +++ b/x/mlxrunner/mlx/include/mlx/c/fft.h @@ -28,11 +28,18 @@ extern "C" { */ /**@{*/ +typedef enum mlx_fft_norm_ { + MLX_FFT_NORM_BACKWARD, + MLX_FFT_NORM_ORTHO, + MLX_FFT_NORM_FORWARD +} mlx_fft_norm; + int mlx_fft_fft( mlx_array* res, const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_fft2( mlx_array* res, @@ -41,7 +48,9 @@ int mlx_fft_fft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); +int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s); int mlx_fft_fftn( mlx_array* res, const mlx_array a, @@ -49,6 +58,7 @@ int mlx_fft_fftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_fftshift( mlx_array* res, @@ -61,6 +71,7 @@ int mlx_fft_ifft( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_ifft2( mlx_array* res, @@ -69,6 +80,7 @@ int mlx_fft_ifft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_ifftn( mlx_array* res, @@ -77,6 +89,7 @@ int mlx_fft_ifftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_ifftshift( mlx_array* res, @@ -89,6 +102,7 @@ int mlx_fft_irfft( const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_irfft2( mlx_array* res, @@ -97,6 +111,7 @@ int mlx_fft_irfft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_irfftn( mlx_array* res, @@ -105,12 +120,14 @@ int mlx_fft_irfftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_rfft( mlx_array* res, const mlx_array a, int n, int axis, + mlx_fft_norm norm, const mlx_stream s); int mlx_fft_rfft2( mlx_array* res, @@ -119,7 +136,9 @@ int mlx_fft_rfft2( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); +int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s); int mlx_fft_rfftn( mlx_array* res, const mlx_array a, @@ -127,6 +146,7 @@ int mlx_fft_rfftn( size_t n_num, const int* axes, size_t axes_num, + mlx_fft_norm norm, const mlx_stream s); /**@}*/ diff --git a/x/mlxrunner/mlx/include/mlx/c/graph_utils.h b/x/mlxrunner/mlx/include/mlx/c/graph_utils.h new file mode 100644 index 000000000..81eec7a87 --- /dev/null +++ b/x/mlxrunner/mlx/include/mlx/c/graph_utils.h @@ -0,0 +1,61 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_GRAPH_UTILS_H +#define MLX_GRAPH_UTILS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup graph_utils Graph Utils + */ +/**@{*/ + +typedef struct mlx_node_namer_ { + void* ctx; +} mlx_node_namer; + +mlx_node_namer mlx_node_namer_new(); +int mlx_node_namer_free(mlx_node_namer namer); +int mlx_node_namer_set_name( + mlx_node_namer namer, + const mlx_array arr, + const char* name); +int mlx_node_namer_get_name( + const char** name, + mlx_node_namer namer, + const mlx_array arr); + +int mlx_export_to_dot( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs); +int mlx_print_graph( + FILE* os, + const mlx_node_namer namer, + const mlx_vector_array outputs); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/x/mlxrunner/mlx/include/mlx/c/io.h b/x/mlxrunner/mlx/include/mlx/c/io.h index 6eb205c9a..e1a0c0a7d 100644 --- a/x/mlxrunner/mlx/include/mlx/c/io.h +++ b/x/mlxrunner/mlx/include/mlx/c/io.h @@ -33,6 +33,9 @@ int mlx_load_reader( mlx_io_reader in_stream, const mlx_stream s); int mlx_load(mlx_array* res, const char* file, const mlx_stream s); + +int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s); + int mlx_load_safetensors_reader( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, @@ -45,6 +48,8 @@ int mlx_load_safetensors( const mlx_stream s); int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); int mlx_save(const char* file, const mlx_array a); +int mlx_save_gguf(const char* file, mlx_io_gguf gguf); + int mlx_save_safetensors_writer( mlx_io_writer in_stream, const mlx_map_string_to_array param, diff --git a/x/mlxrunner/mlx/include/mlx/c/io_types.h b/x/mlxrunner/mlx/include/mlx/c/io_types.h index 88349b57c..5382e7342 100644 --- a/x/mlxrunner/mlx/include/mlx/c/io_types.h +++ b/x/mlxrunner/mlx/include/mlx/c/io_types.h @@ -95,6 +95,52 @@ int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); */ int mlx_io_writer_free(mlx_io_writer io); +/** + * A MLX GGUF object. + */ +typedef struct mlx_io_gguf_ { + void* ctx; +} mlx_io_gguf; + +mlx_io_gguf mlx_io_gguf_new(void); +int mlx_io_gguf_free(mlx_io_gguf io); +int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io); +int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key); +int mlx_io_gguf_get_metadata_array( + mlx_array* arr, + mlx_io_gguf io, + const char* key); +int mlx_io_gguf_get_metadata_string( + mlx_string* str, + mlx_io_gguf io, + const char* key); +int mlx_io_gguf_get_metadata_vector_string( + mlx_vector_string* vstr, + mlx_io_gguf io, + const char* key); +int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key); +int mlx_io_gguf_has_metadata_string( + bool* flag, + mlx_io_gguf io, + const char* key); +int mlx_io_gguf_has_metadata_vector_string( + bool* flag, + mlx_io_gguf io, + const char* key); +int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr); +int mlx_io_gguf_set_metadata_array( + mlx_io_gguf io, + const char* key, + const mlx_array marr); +int mlx_io_gguf_set_metadata_string( + mlx_io_gguf io, + const char* key, + const char* mstr); +int mlx_io_gguf_set_metadata_vector_string( + mlx_io_gguf io, + const char* key, + const mlx_vector_string mvstr); + /**@}*/ #ifdef __cplusplus diff --git a/x/mlxrunner/mlx/include/mlx/c/mlx.h b/x/mlxrunner/mlx/include/mlx/c/mlx.h index ffadac89a..2aa9077c6 100644 --- a/x/mlxrunner/mlx/include/mlx/c/mlx.h +++ b/x/mlxrunner/mlx/include/mlx/c/mlx.h @@ -14,6 +14,7 @@ #include "mlx/c/export.h" #include "mlx/c/fast.h" #include "mlx/c/fft.h" +#include "mlx/c/graph_utils.h" #include "mlx/c/half.h" #include "mlx/c/io.h" #include "mlx/c/io_types.h" diff --git a/x/mlxrunner/mlx/include/mlx/c/ops.h b/x/mlxrunner/mlx/include/mlx/c/ops.h index 64d70e2f4..44fc09c4d 100644 --- a/x/mlxrunner/mlx/include/mlx/c/ops.h +++ b/x/mlxrunner/mlx/include/mlx/c/ops.h @@ -1004,6 +1004,50 @@ int mlx_slice_update_dynamic( const int* axes, size_t axes_num, const mlx_stream s); +int mlx_slice_update_add( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_max( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_min( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_prod( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); int mlx_softmax_axes( mlx_array* res, const mlx_array a, diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go index 9b01b4a85..fa3c112dc 100644 --- a/x/mlxrunner/mlx/stream.go +++ b/x/mlxrunner/mlx/stream.go @@ -3,10 +3,7 @@ package mlx // #include "generated.h" import "C" -import ( - "log/slog" - "sync" -) +import "log/slog" type Device struct { ctx C.mlx_device @@ -19,11 +16,28 @@ func (d Device) LogValue() slog.Value { return slog.StringValue(C.GoString(C.mlx_string_data(str))) } -var DefaultDevice = sync.OnceValue(func() Device { - d := C.mlx_device_new() - C.mlx_get_default_device(&d) - return Device{d} -}) +var ( + defaultDevice Device + defaultDeviceSet bool + defaultStream Stream + defaultStreamSet bool +) + +func resetDefaultStreamCache() { + defaultDeviceSet = false + defaultStreamSet = false +} + +func DefaultDevice() Device { + if !defaultDeviceSet { + d := C.mlx_device_new() + C.mlx_get_default_device(&d) + defaultDevice = Device{d} + defaultDeviceSet = true + } + + return defaultDevice +} // GPUIsAvailable returns true if a GPU device is available. func GPUIsAvailable() bool { @@ -39,6 +53,7 @@ func SetDefaultDeviceGPU() { dev := C.mlx_device_new_type(C.MLX_GPU, 0) C.mlx_set_default_device(dev) C.mlx_device_free(dev) + resetDefaultStreamCache() } type Stream struct { @@ -52,8 +67,13 @@ func (s Stream) LogValue() slog.Value { return slog.StringValue(C.GoString(C.mlx_string_data(str))) } -var DefaultStream = sync.OnceValue(func() Stream { - s := C.mlx_stream_new() - C.mlx_get_default_stream(&s, DefaultDevice().ctx) - return Stream{s} -}) +func DefaultStream() Stream { + if !defaultStreamSet { + s := C.mlx_stream_new() + C.mlx_get_default_stream(&s, DefaultDevice().ctx) + defaultStream = Stream{s} + defaultStreamSet = true + } + + return defaultStream +} diff --git a/x/mlxrunner/mlx/thread_test.go b/x/mlxrunner/mlx/thread_test.go new file mode 100644 index 000000000..34d15ac42 --- /dev/null +++ b/x/mlxrunner/mlx/thread_test.go @@ -0,0 +1,104 @@ +package mlx + +import ( + "context" + "runtime" + "sync" + "testing" + + "github.com/ollama/ollama/x/internal/mlxthread" +) + +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + +func startMLXThread(t *testing.T) *mlxthread.Thread { + t.Helper() + + thread, err := mlxthread.Start("mlx-test", func() error { + if err := CheckInit(); err != nil { + return err + } + if GPUIsAvailable() { + SetDefaultDeviceGPU() + } + return nil + }) + if err != nil { + t.Skipf("MLX not available: %v", err) + } + + return thread +} + +func stopMLXThread(t *testing.T, thread *mlxthread.Thread) { + t.Helper() + + if err := thread.Stop(context.Background(), func() { + Sweep() + ClearCache() + resetDefaultStreamCache() + }); err != nil { + t.Fatal(err) + } +} + +func withMLXThread(t *testing.T, fn func()) { + t.Helper() + + thread := startMLXThread(t) + defer stopMLXThread(t, thread) + + if err := thread.Do(context.Background(), func() error { + fn() + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestThreadedMLXOperations(t *testing.T) { + thread := startMLXThread(t) + defer stopMLXThread(t, thread) + + oldProcs := runtime.GOMAXPROCS(8) + defer runtime.GOMAXPROCS(oldProcs) + + const goroutines = 8 + const iterations = 8 + + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + for range goroutines { + wg.Add(1) + go func() { + defer wg.Done() + + for range iterations { + if err := thread.Do(context.Background(), func() error { + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + b := Matmul(a, a) + AsyncEval(b) + Eval(b) + Sweep() + ClearCache() + return nil + }); err != nil { + errCh <- err + return + } + } + }() + } + + wg.Wait() + close(errCh) + + for err := range errCh { + t.Fatal(err) + } +} diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 6e23471b1..8f90c295e 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -11,6 +11,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/x/internal/mlxthread" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" @@ -39,6 +40,7 @@ type Runner struct { Sampler *sample.Sampler cache kvCache contextLength int + mlxThread *mlxthread.Thread } func (r *Runner) Load(modelName string) error { @@ -137,7 +139,8 @@ func (r *Runner) Run(host, port string, mux http.Handler) error { case <-ctx.Done(): return nil case request := <-r.Requests: - if err := request.Pipeline(request.Ctx, request); err != nil { + err := r.runRequest(request) + if err != nil { slog.Info("Request terminated", "error", err) var statusErr api.StatusError if !errors.As(err, &statusErr) { @@ -164,3 +167,13 @@ func (r *Runner) Run(host, port string, mux http.Handler) error { return g.Wait() } + +func (r *Runner) runRequest(request Request) error { + if r.mlxThread == nil { + return request.Pipeline(request.Ctx, request) + } + + return r.mlxThread.Do(request.Ctx, func() error { + return request.Pipeline(request.Ctx, request) + }) +} diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index d44de3d2b..c31d38167 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -15,6 +15,7 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/x/internal/mlxthread" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/sample" ) @@ -22,17 +23,6 @@ import ( func Execute(args []string) error { slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) - if err := mlx.CheckInit(); err != nil { - return fmt.Errorf("MLX not available: %w", err) - } - - if mlx.GPUIsAvailable() { - mlx.SetDefaultDeviceGPU() - slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu") - } else { - slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu") - } - var ( modelName string port int @@ -44,21 +34,55 @@ func Execute(args []string) error { _ = flagSet.Bool("verbose", false, "Enable debug logging") flagSet.Parse(args) + worker, err := mlxthread.Start("mlxrunner", func() error { + if err := mlx.CheckInit(); err != nil { + return fmt.Errorf("MLX not available: %w", err) + } + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu") + } else { + slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu") + } + + return nil + }) + if err != nil { + return err + } + defer worker.Stop(context.Background(), func() { + mlx.Sweep() + mlx.ClearCache() + }) + runner := Runner{ - Requests: make(chan Request), + Requests: make(chan Request), + mlxThread: worker, } - if err := runner.Load(modelName); err != nil { + if err := worker.Do(context.Background(), func() error { + return runner.Load(modelName) + }); err != nil { return err } mux := http.NewServeMux() mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { + memory, err := mlxthread.Call(r.Context(), worker, func() (uint64, error) { + return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil + }) + if err != nil { + slog.Error("Failed to read MLX memory status", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(statusResponse{ Status: 0, Progress: 100, ContextLength: runner.contextLength, - Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()), + Memory: memory, }); err != nil { slog.Error("Failed to encode response", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError)