Getting Started with MAX Engine C API

March 14, 2024

Ehsan M. Kermani

AI DevRel

In this blog post, we introduce the MAX Engine C API and gradually explore its capabilities. The C API lets you integrate MAX Engine into high-performance application code and run inference with models from PyTorch, TensorFlow, and ONNX in environments that should not depend on Python. We'll guide you through a minimal setup for the C API, from setting up the runtime context to applying it to a text-image similarity task with the OpenAI CLIP model. If you missed the getting started guide for the MAX Engine Python API, please check out Getting Started with MAX Developer Edition.

It's important to consult the official documentation for detailed information. All code for this blog is available in our GitHub repository.

At a high level, working with MAX Engine for inference involves three main steps:

  1. Compiling the model
  2. Initializing the model
  3. Running the inference

With this framework in mind, let's begin by verifying the setup so you can follow along with the examples.

Hello, world!

The code for this blog post has been tested with MAX version 24.1.0-c176f84d-release. Please ensure that your system meets the requirements and that MAX is installed by following the Getting Started guide. This blog also requires CMake version 3.24 or higher.

Verifying the setup and version

To familiarize ourselves with the API and ensure everything is functioning correctly, we'll start with the simplest case: retrieving the version using M_version(). This function is declared in the max/c/common.h header file, and its use is demonstrated in basics/main.c as follows:

C
#include "max/c/common.h"
#include <stdio.h>
#include <stdlib.h>

int main() {
  const char *version = M_version();
  printf("MAX Engine version: %s\n", version);
  return EXIT_SUCCESS;
}

Then, include the following in your CMakeLists.txt:

cmake
cmake_minimum_required(VERSION 3.24)
project(example LANGUAGES CXX C)

list(APPEND CMAKE_MODULE_PATH "$ENV{MAX_PKG_DIR}/lib/cmake")
include(AddMaxEngine)

add_executable(basics main.c)
target_link_libraries(basics PUBLIC max-engine)

Finally, to compile and run, follow these steps:

  1. Set the MAX path
bash
MAX_PKG_DIR="$(modular config max.path)"
export MAX_PKG_DIR
  2. Compile
bash
cmake -B build -S .
cmake --build build
  3. Execute
bash
./build/basics

We should get the following output:

Output
-- Configuring done (0.0s)
-- Generating done (0.0s)
[ 50%] Building C object CMakeFiles/basics.dir/main.c.o
[100%] Linking C executable basics
[100%] Built target basics
MAX Engine version: 24.1.0-c176f84d-release

Now that we have verified the MAX Engine version 24.1.0-c176f84d-release, we can proceed to the next step.

Setting the runtime context

We are now ready to take a deeper look at the MAX Engine C API. Note that all C API functions are prefixed with M_, as we saw with M_version. The first step in working with MAX Engine is to establish the runtime context. This application-level object configures various resources, such as thread pools and allocators, for use during inference. It is advisable to create a single context and use it throughout your application. This can be accomplished by:

  1. Creating a status object that we will use throughout the application
  2. Configuring the runtime settings
  3. Initializing the runtime context with the specified configuration

For error handling, functions such as M_isError and M_getError can be used together with the status object created above. These functions are available in the max/c/context.h header file.

C
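// 1. Create a status object used throughout the application
M_Status *status = M_newStatus();
// 2. Configure the runtime settings
M_RuntimeConfig *runtimeConfig = M_newRuntimeConfig();
// 3. Initialize the runtime context with that configuration
M_RuntimeContext *context = M_newRuntimeContext(runtimeConfig, status);
if (M_isError(status)) {
  printf("Error: %s\n", M_getError(status));
  return EXIT_FAILURE;
}
printf("Context is setup\n");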

We can verify the setup by executing run.sh, which should print the message "Context is setup". The C API also lets us adjust the runtime configuration, for example setting the number of threads via M_setNumThreads, and query it, for example reading the CPU affinity setting with M_getCPUAffinity.

C
M_setNumThreads(runtimeConfig, 1);
size_t numThreads = M_getNumThreads(runtimeConfig);
bool cpuAffinity = M_getCPUAffinity(runtimeConfig);

Now, let's explore a real-world application.

CLIP inference with MAX Engine

CLIP is a multi-modal deep learning model based on the transformer architecture, developed by OpenAI. It is notable for its zero-shot image classification capability. In this blog, we use the openai/clip-vit-base-patch32 CLIP model available on HuggingFace. We focus on using CLIP for the text-image similarity task, which takes text and image inputs and assigns a similarity score to each text-image pair. The code for this example is available in our GitHub repository.

Convert to TorchScript format

The first step is to convert the model to TorchScript format. We do this by tracing the model with dummy inputs that match its input structure, which for the CLIP model is:

  • input_ids of size (1, 77)
  • pixel_values of size (1, 3, 224, 224)
  • attention_mask of size (1, 77)
Python
import os

import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
model.config.return_dict = False  # return plain tuples instead of ModelOutput objects

# Dummy inputs matching CLIP's input layout
inputs = {
    "input_ids": torch.ones(1, 77, dtype=torch.long),
    "pixel_values": torch.rand(1, 3, 224, 224),
    "attention_mask": torch.ones((1, 77), dtype=torch.long),
}

with torch.no_grad():
    traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(inputs), strict=False)

os.makedirs("models", exist_ok=True)  # make sure the output directory exists
traced_model.save("models/clip_vit.torchscript")
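The input shapes listed above also fix how large each raw buffer is, which matters once we pass raw pointers from C. As a quick sanity check, the plain-C sketch below multiplies out a shape; the helpers `numElements` and `bufferBytes` are hypothetical, not part of the MAX API. For `pixel_values` of shape (1, 3, 224, 224) stored as 32-bit floats, that is 150,528 elements, or 602,112 bytes.

```c
#include <stddef.h>
#include <stdint.h>

/* Number of elements in a tensor with the given shape (product of its dims).
   Hypothetical helper for illustration; not part of the MAX C API. */
int64_t numElements(const int64_t *shape, size_t rank) {
  int64_t count = 1;
  for (size_t i = 0; i < rank; ++i) {
    count *= shape[i];
  }
  return count;
}

/* Size in bytes of a dense buffer holding such a tensor,
   given the size of one element (e.g. sizeof(float)). */
size_t bufferBytes(const int64_t *shape, size_t rank, size_t elementSize) {
  return (size_t)numElements(shape, rank) * elementSize;
}
```

For example, `input_ids` of shape (1, 77) with 64-bit integers needs 77 elements, or 616 bytes.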

Pre-process inputs

To demonstrate the text-image similarity task, we use two text inputs:

  1. “a photo of a cat”
  2. “a photo of a dog”

and an image from the COCO dataset (http://images.cocodataset.org/val2017/000000039769.jpg).

We expect the model to identify the cats in the image, assigning a similarity score close to 1 (high similarity) to the first input and close to 0 (low similarity) to the second.

The preprocessing steps include downloading the image, configuring the CLIPProcessor, generating the input data, and finally converting the inputs into a binary format. For simplicity, we do this preprocessing in Python with the following script, which configures CLIPProcessor, processes our text and image inputs, and saves the raw tensors to disk. We can then load these saved tensors when we invoke our model from C.

Python
from PIL import Image
import requests
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

text = ["a photo of a cat", "a photo of a dog"]
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text, images=image, return_tensors="pt", padding=True)

Then we save the inputs as well as their shapes in binary format.

Output
inputs/
├── attention_mask.bin
├── attention_mask_shape.bin
├── input_ids.bin
├── input_ids_shape.bin
├── pixel_values.bin
└── pixel_values_shape.bin
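On the C side, these files have to be read back into memory before they can be handed to MAX Engine. A minimal sketch, assuming each `.bin` file holds raw, native-endian values with no header (as NumPy's `tofile` would write them) and that each `*_shape.bin` file stores its dimensions as `int64` values. Both the file format and the helper names `readBinaryFile` and `loadShape` are assumptions for illustration, not the repository's actual code.

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Read an entire file into a newly allocated buffer (assumed raw bytes, no header).
   Returns NULL on failure; on success stores the byte count in *outSize.
   The caller owns the returned buffer and must free() it. Hypothetical helper. */
void *readBinaryFile(const char *path, size_t *outSize) {
  FILE *file = fopen(path, "rb");
  if (file == NULL) return NULL;

  if (fseek(file, 0, SEEK_END) != 0) { fclose(file); return NULL; }
  long size = ftell(file);
  if (size < 0 || fseek(file, 0, SEEK_SET) != 0) { fclose(file); return NULL; }

  /* malloc(0) may return NULL, so always allocate at least one byte. */
  void *buffer = malloc(size > 0 ? (size_t)size : 1);
  if (buffer == NULL) { fclose(file); return NULL; }

  if (fread(buffer, 1, (size_t)size, file) != (size_t)size) {
    free(buffer);
    fclose(file);
    return NULL;
  }
  fclose(file);
  *outSize = (size_t)size;
  return buffer;
}

/* Load a shape file, assumed to contain only int64 dimensions.
   The rank is the number of int64 values in the file. Hypothetical helper. */
int64_t *loadShape(const char *path, size_t *outRank) {
  size_t bytes = 0;
  int64_t *shape = readBinaryFile(path, &bytes);
  if (shape == NULL) return NULL;
  if (bytes % sizeof(int64_t) != 0) { free(shape); return NULL; }
  *outRank = bytes / sizeof(int64_t);
  return shape;
}
```

The same `readBinaryFile` call would load the tensor data files themselves; the buffers stay valid until you `free()` them.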

Model compilation

We are now ready to compile the model, run inference, and save the result for post-processing. Given an existing runtime context, we proceed by:

  1. Creating a compilation configuration object with M_newCompileConfig and setting the model path.
C
M_CompileConfig *compileConfig = M_newCompileConfig();
M_setModelPath(compileConfig, /*path=*/modelPath);
  2. Specifying the input specs, including data types and shape information, for compilation. Note that this step is only required for TorchScript models.
C
int64_t *inputIdsShape = ...
M_TorchInputSpec *inputIdsInputSpec =
    M_newTorchInputSpec(inputIdsShape, /*rankSize=*/2, /*dtype=*/M_INT64);

int64_t *pixelValuesShape = ...
M_TorchInputSpec *pixelValuesInputSpec =
    M_newTorchInputSpec(pixelValuesShape, /*rankSize=*/4, /*dtype=*/M_FLOAT32);

int64_t *attentionMaskShape = ...
M_TorchInputSpec *attentionMaskInputSpec = M_newTorchInputSpec(
    attentionMaskShape, /*rankSize=*/2, /*dtype=*/M_INT64);

M_TorchInputSpec *inputSpecs[3] = {inputIdsInputSpec, pixelValuesInputSpec,
                                   attentionMaskInputSpec};
M_setTorchInputSpecs(compileConfig, inputSpecs, 3);
  3. Compiling the model via M_compileModel. Note that by default, MAX Engine compiles the model asynchronously; M_compileModel() returns immediately. An M_CompileConfig can only be used for a single compilation call. Any subsequent calls require a new M_CompileConfig.
C
M_AsyncCompiledModel *compiledModel =
    M_compileModel(context, &compileConfig, status);

Model initialization

The three steps above compile the model, but the M_AsyncCompiledModel returned by M_compileModel is not ready for inference yet. To prepare it, we initialize the model by calling M_initModel, which returns an instance of M_AsyncModel. This step prepares the compiled model for fast execution by running and initializing the graph operations that are input-independent. Since M_initModel is also asynchronous, it returns immediately as well. To wait for it to finish, call M_waitForModel.

C
M_AsyncModel *model = M_initModel(context, compiledModel, status);
M_waitForModel(model, status);

Run inference and obtain the results

We are now ready to run inference, which takes three steps:

  1. Prepare the inputs. The inputs are passed in a TensorMap object, which can be created with M_newAsyncTensorMap. Once this object is created, we can add our inputs to it. Since we operate on raw data pointers, we also need descriptors of the data they point to, specifying the data type and shape of the contained tensors. This is done by creating tensor specs. Note that while this might look identical to creating input specs for compilation, it serves a different purpose: at compilation time, input specs instruct the compiler to specialize the compiled model for specific shapes (the shapes don't need to be static; some dimensions can be dynamic), whereas at inference time, tensor specs simply describe the input values we pass in. We add each input by calling M_borrowTensorInto, passing it the input tensor and the corresponding tensor specification (shape, type, etc.) as an M_TensorSpec. Recall that the inputs have the following shapes, whose ranks we use next:
  • input_ids of size (1, 77)
  • pixel_values of size (1, 3, 224, 224)
  • attention_mask of size (1, 77)
C
M_AsyncTensorMap *inputToModel = M_newAsyncTensorMap(context);

M_TensorSpec *inputIdsSpec =
    M_newTensorSpec(inputIdsShape, /*rankSize=*/2, /*dtype=*/M_INT64,
                    /*tensorName=*/"input_ids");
int64_t *inputIdsTensor = ...
M_borrowTensorInto(inputToModel, inputIdsTensor, inputIdsSpec, status);

M_TensorSpec *pixelValuesSpec =
    M_newTensorSpec(pixelValuesShape, /*rankSize=*/4, /*dtype=*/M_FLOAT32,
                    /*tensorName=*/"pixel_values");
float *pixelValuesTensor = ...
M_borrowTensorInto(inputToModel, pixelValuesTensor, pixelValuesSpec, status);

M_TensorSpec *attentionMaskSpec =
    M_newTensorSpec(attentionMaskShape, /*rankSize=*/2, /*dtype=*/M_INT64,
                    /*tensorName=*/"attention_mask");
int64_t *attentionMaskTensor = ...
M_borrowTensorInto(inputToModel, attentionMaskTensor, attentionMaskSpec, status);
  2. Execute the inference synchronously via M_executeModelSync.
C
M_AsyncTensorMap *outputs = M_executeModelSync(context, model, inputToModel, status);
  3. Finally, get the result value and save it for post-processing. Also remember to free resources using the M_free* APIs, such as M_freeValue and M_freeModel.
C
M_AsyncValue *resultValue =
    M_getValueByNameFrom(outputs, /*tensorName=*/"result0", status);
M_freeValue(resultValue);
M_freeModel(model);
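The post-processing script in the next section reads `outputs.bin` with `np.fromfile(..., dtype=np.float32)`, that is, as a flat array of 32-bit floats with no header, so saving the result amounts to dumping the tensor data verbatim. A minimal sketch in plain C, assuming you have already obtained a `const float *` pointer to the result data and its element count through the C API (that step is not shown here); `writeFloatArray` is a hypothetical helper, not a MAX function.

```c
#include <stddef.h>
#include <stdio.h>

/* Write `count` floats to `path` as raw native-endian bytes with no header,
   the layout np.fromfile(path, dtype=np.float32) expects.
   Returns 0 on success and -1 on failure. Hypothetical helper; the caller is
   assumed to have obtained `data` and `count` from the result tensor. */
int writeFloatArray(const char *path, const float *data, size_t count) {
  FILE *file = fopen(path, "wb");
  if (file == NULL) return -1;
  size_t written = fwrite(data, sizeof(float), count, file);
  int closeStatus = fclose(file);
  return (written == count && closeStatus == 0) ? 0 : -1;
}
```

For example, `writeFloatArray("outputs.bin", data, count)` would produce the file the Python script loads.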

Post-processing the result

The post-processing steps are straightforward: we load the saved logits and apply a softmax to turn them into scores.

Python
import numpy as np
import torch

logits = torch.from_numpy(np.fromfile("outputs.bin", dtype=np.float32))
scores = logits.softmax(dim=-1)
print(f"Scores: {scores.numpy()}")

which outputs

Output
[0.994858 0.00514201]

As a sanity check, we can further validate the output by running the equivalent HuggingFace transformers script:

Python
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    return_tensors="pt",
    padding=True,
)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity logits
probs = logits_per_image.softmax(dim=1)  # softmax over the text prompts
print(probs)

which should produce the same scores:

Output
[0.994858 0.00514201]

Conclusion

We hope this introductory guide has equipped you with the foundational knowledge needed to integrate the MAX Engine C API into your high-performance machine learning applications. By walking you through the essential steps of setting up the runtime context, compiling and initializing the model, preparing inputs, running inference, and post-processing the outputs, we aimed to show how smoothly MAX Engine runs models from frameworks such as PyTorch, TensorFlow, and ONNX, using a PyTorch model as the example here. Notably, the inference executable in our CLIP example has a binary size of just 36KB, so it can be included directly in containers for serving without Python dependencies. The versatility and efficiency of the MAX Engine C API make it an excellent choice for applications that require robust performance and scalability. As you move forward, we encourage you to experiment with different models and configurations, explore the documentation for deeper insights, and engage with the community for support and knowledge sharing.

Download MAX, try it out, and share your feedback with us!

Until next time!🔥

Additional resources:

Report feedback, including issues on our Mojo and MAX GitHub tracker.
