December 3, 2023

Mojo 🔥 Traits Have Arrived!

Jack Clayton

Developer

Introduction

Statically typed languages need some mechanism to express generic functions; dynamically typed languages don't enforce types, so every function is effectively generic. Many languages implement some form of this feature under a variety of names:


Traits      : Rust, Scala, Crystal, Mojo 🔥
Interfaces  : Go, Java, C#, TypeScript, Kotlin, PHP, Dart
Protocols   : Swift, Objective-C, Elixir, Python
Typeclasses : Haskell
Concepts    : C++

Python

In dynamic languages like JavaScript and Python, you can pass any object to any function. But you must make sure the object implements the methods you call in that function, or you'll get a nasty runtime error:

Python

class Foo:
    pass

def cause_error(foo):
    foo.will_fail()

foo = Foo()
cause_error(foo)
Output

Traceback (most recent call last):
  File "main.py", line 8, in <module>
    cause_error(foo)
  File "main.py", line 5, in cause_error
    foo.will_fail()
    ^^^^^^^^^^^^^
AttributeError: 'Foo' object has no attribute 'will_fail'

Python 3.8 introduced a typing feature named protocols, which is related to traits. Here's an example of a protocol named Shape that declares the signature of an abstract method, area, without implementing it:

Python

from typing import Protocol

class Shape(Protocol):
    def area(self) -> float:
        ...

This is similar to Mojo's traits in both syntax and concept, but in Python it only produces linter and type-checker warnings, plus quality-of-life improvements like better completions. With Mojo, you also pay no performance penalty for writing generic functions across different types, and you can't ship these mistakes to production, because the program simply won't compile.
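To make the difference concrete, here is a minimal Python sketch (Square, Blob, and area_of are hypothetical names). A static checker such as mypy would reject the second call, but the interpreter ignores annotations, so the mistake only surfaces as an AttributeError when area() is actually called:

```python
from typing import Protocol


class Shape(Protocol):
    def area(self) -> float:
        ...


class Square:
    # Never mentions Shape: protocols are structural, so a matching
    # area() method is enough for a type checker to accept it.
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side ** 2


class Blob:
    # No area() method, so a type checker rejects Blob where a Shape is
    # expected, but nothing stops it at runtime.
    pass


def area_of(shape: Shape) -> float:
    return shape.area()


print(area_of(Square(3.0)))  # 9.0

try:
    # mypy would flag this call; Python only fails once area() runs.
    area_of(Blob())  # type: ignore[arg-type]
except AttributeError as err:
    print(f"runtime error: {err}")
```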

Mojo 🔥

Let's take a look at Mojo traits:

Mojo

trait Shape:
    fn area(self) -> Float64:
        ...

We can now create a function that accepts anything implementing the Shape trait and calls its abstract area method:

Mojo

fn print_area[T: Shape](shape: T):
    print(shape.area())

The [T: Shape] at the start is a syntax common across languages: you can think of T as declaring a generic type that is constrained by the Shape trait.
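For Python readers, a rough analogue of [T: Shape] is a TypeVar bounded by a protocol. This is a sketch for comparison only: the bound is enforced by tools like mypy, not at runtime, and the Circle class here is a plain Python stand-in for the Mojo Circle struct:

```python
from typing import Protocol, TypeVar


class Shape(Protocol):
    def area(self) -> float:
        ...


# T plays the role of Mojo's [T: Shape]: any type that satisfies Shape.
T = TypeVar("T", bound=Shape)


def print_area(shape: T) -> T:
    # Print the area, then return the same object with its precise type T
    # (not just Shape), which is what a bound TypeVar adds over `shape: Shape`.
    print(shape.area())
    return shape


class Circle:
    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return 3.141592653589793 * self.radius ** 2


c = print_area(Circle(radius=1.5))  # prints 7.0685834705770345
print(c.radius)  # a checker knows c is a Circle, so .radius is allowed
```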

Now we'll create a concrete fn area implementation on Circle:

Mojo

@value
struct Circle(Shape):
    var radius: Float64

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

Circle(Shape) means that the struct Circle must implement all the methods specified in the Shape trait; in this case, that's just the fn area(self) -> Float64 signature.

You can now run it through the generic function:

Mojo

let circle = Circle(radius=1.5)
print_area(circle)
Output

7.0685834705770345

If we remove fn area from Circle, the compiler won't build the program, so this mistake can never become a runtime error. Another advantage over Python is that the error tells us exactly which signature we need to implement:


/tmp/traits.mojo:11:1: error: struct 'Circle' does not implement all requirements for 'Shape'
struct Circle(Shape):
^
/tmp/traits.mojo:2:5: note: required function 'area' is not implemented
    fn area(self) -> Float64:
    ^
/tmp/traits.mojo:1:1: note: trait 'Shape' declared here
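For contrast, the closest Python comes to this check is an abstract base class from the standard abc module (IncompleteCircle is a hypothetical name). Python does refuse to create the object, but only at runtime, when the incomplete class is instantiated, and the exact error wording varies between Python versions:

```python
from abc import ABC, abstractmethod


class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        ...


class IncompleteCircle(Shape):
    # area() is missing, yet defining the class is still allowed.
    def __init__(self, radius: float) -> None:
        self.radius = radius


try:
    IncompleteCircle(1.5)
except TypeError as err:
    # The failure only surfaces when an instance is created.
    print(err)
```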

Let's create another type that implements the Shape trait:

Mojo

@value
struct Rectangle(Shape):
    var width: Float64
    var height: Float64

    fn area(self) -> Float64:
        return self.width * self.height

And run it through the same function:

Mojo

let rectangle = Rectangle(width=2, height=3)
print_area(rectangle)
Output

6.0

Truly Zero-Cost Generics

A popular C++ pattern for achieving this behavior is inheritance with abstract classes, but the compiler can't always reason about which concrete type is used when calling methods through an abstract class, which can have a significant performance impact. The pattern can also make a code base explode in complexity.

C++ added multiple features to address these problems, such as templates and concepts, but they still aren't completely zero-cost, whereas Mojo can guarantee that values are register-passable when using traits, for truly zero-cost generics.

Multiple Traits

The __str__ method comes from Python; it determines what happens when you print() the type. We added a Stringable trait to the standard library, which you can implement on your type to make it printable:

Mojo

trait Stringable:
    fn __str__(self) -> String:
        ...

Let's add the Stringable trait to Circle:

Mojo

@value
struct Circle(Shape, Stringable):
    var radius: Float64

    fn __str__(self) -> String:
        var res = String("Circle(radius=")
        res += self.radius
        res += ", area="
        res += self.area()
        res += ")"
        return res

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

Circle now takes the Shape and Stringable traits, so it must implement fn area and fn __str__ to compile.

This allows us to print the type just like Python:

Mojo

let circle = Circle(radius=1.5)
print(circle)
Output

Circle(radius=1.5, area=7.0685834705770345)
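Since __str__ comes from Python, here is a plain Python sketch of the same Circle for comparison; print() calls str() on its argument, and the f-string formatting is our own choice:

```python
class Circle:
    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return 3.141592653589793 * self.radius ** 2

    def __str__(self) -> str:
        # print() calls str(), which dispatches to this method.
        return f"Circle(radius={self.radius}, area={self.area()})"


print(Circle(radius=1.5))  # Circle(radius=1.5, area=7.0685834705770345)
```

Python prints the shortest decimal that round-trips to the same float, which here happens to match the Mojo output.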

Trait Inheritance

A really cool feature of traits is that they let users compose their types with your library. We've added a few simple traits to our standard library so you can inherit them into your own types.

Here's an example of creating a new trait that inherits from Shape along with the standard library traits Stringable and CollectionElement, so that you can push your type into a DynamicVector. The @value decorator implements the requirements for CollectionElement.

Mojo

trait VecPrintableShape(CollectionElement, Stringable, Shape):
    ...

@value
struct Circle(VecPrintableShape):
    var radius: Float64

    fn __str__(self) -> String:
        var res = String("Circle(radius=")
        res += self.radius
        res += ", area="
        res += self.area()
        res += ")"
        return res

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

fn main():
    var vec = DynamicVector[Circle](3)
    for i in range(3):
        vec.push_back(Circle(i))
        print(vec[i])
Output

Circle(radius=0.0, area=0.0)
Circle(radius=1.0, area=3.1415926535897931)
Circle(radius=2.0, area=12.566370614359172)

Now that we have a trait that composes our three traits together, we can create a function that accepts a DynamicVector of any type implementing it:

Mojo

fn print_vec[T: VecPrintableShape](vec: DynamicVector[T]):
    for i in range(len(vec)):
        print(vec[i])

print_vec[Circle](vec)
Output

Circle(radius=0.0, area=0.0)
Circle(radius=1.0, area=3.1415926535897931)
Circle(radius=2.0, area=12.566370614359172)
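Python protocols compose the same way. The sketch below is only an analogy and assumes Python 3.9+ for the list[T] annotation; PrintableShape and print_all are hypothetical names. A protocol that inherits from other protocols must list Protocol again as a base, otherwise it becomes an ordinary (non-protocol) class:

```python
from typing import Protocol, TypeVar


class Shape(Protocol):
    def area(self) -> float:
        ...


class Stringable(Protocol):
    def __str__(self) -> str:
        ...


# Composed protocol: anything with both area() and __str__().
# Protocol is repeated as a base to keep this a protocol.
class PrintableShape(Shape, Stringable, Protocol):
    ...


T = TypeVar("T", bound=PrintableShape)


def print_all(items: list[T]) -> list[str]:
    # Print each element and return the printed lines for easy checking.
    lines = [str(item) for item in items]
    for line in lines:
        print(line)
    return lines


class Circle:
    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return 3.141592653589793 * self.radius ** 2

    def __str__(self) -> str:
        return f"Circle(radius={self.radius}, area={self.area()})"


# Prints three lines, e.g. Circle(radius=1.0, area=3.141592653589793)
print_all([Circle(float(i)) for i in range(3)])
```

Python prints 3.141592653589793 for the second area rather than Mojo's 3.1415926535897931 because its float repr uses the shortest round-tripping form.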

Try defining your own Rectangle type that implements VecPrintableShape!

Database Trait

If you're still not sure why traits are useful, a common example that demonstrates their utility is a Database trait. We'll define only two methods to keep the concept simple:

Mojo

trait Database:
    fn insert(self, key: String, value: String) raises:
        ...

    fn print_all(self) raises:
        ...

Now we can pass around an object that implements Database and call its abstract methods:

Mojo

fn insert_and_print[T: Database](db: T, key: String, value: String) raises:
    db.insert(key, value)
    db.print_all()

For example, imagine you have a function that runs inference on an image and stores the result somewhere. Maybe you want to use SQLite for local batch tests, but in production the results are stored in a DynamoDB instance. Or maybe you just want the flexibility to swap out the database later without causing breaking changes for users of your library.
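A minimal Python sketch of that swap, assuming a hypothetical InMemoryDatabase backend for tests. Here print_all also returns the rows (a change from the Mojo trait) so the result is easy to check, and insert_and_print relies only on the two Database methods, so any conforming backend can be passed in:

```python
from typing import Protocol


class Database(Protocol):
    def insert(self, key: str, value: str) -> None:
        ...

    def print_all(self) -> list[tuple[str, str]]:
        ...


class InMemoryDatabase:
    # Hypothetical test backend that keeps rows in a Python list.
    def __init__(self) -> None:
        self.rows: list[tuple[str, str]] = []

    def insert(self, key: str, value: str) -> None:
        self.rows.append((key, value))

    def print_all(self) -> list[tuple[str, str]]:
        for row in self.rows:
            print(row)
        return list(self.rows)


def insert_and_print(db: Database, key: str, value: str) -> list[tuple[str, str]]:
    # Only Database methods are used, so the backend is interchangeable.
    db.insert(key, value)
    return db.print_all()


insert_and_print(InMemoryDatabase(), "name", "billy")  # prints ('name', 'billy')
```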

Let's import the sqlite3 Python module to implement a database that runs locally and conforms to the trait above:

Mojo

from python import Python

struct SQLite(Database):
    var con: PythonObject
    var cur: PythonObject

    fn __init__(inout self, path: StringLiteral) raises:
        let sqlite3 = Python.import_module("sqlite3")
        self.con = sqlite3.connect(path)
        self.cur = self.con.cursor()
        _ = self.cur.execute("CREATE TABLE IF NOT EXISTS test (key TEXT, value TEXT)")

    fn insert(self, key: String, value: String) raises:
        var query = String("INSERT INTO test VALUES ('")
        query += key
        query += "', '"
        query += value
        query += "')"
        _ = self.cur.execute(query)
        # Commit so the row is persisted to the database file
        _ = self.con.commit()

    fn print_all(self) raises:
        let query = String("SELECT * FROM test")
        var result = self.cur.execute(query)
        for row in result:
            print(row)

Now we can pass the database to our previous function:

Mojo

let db = SQLite("test.db")
insert_and_print(db, "name", "billy")

The sqlite3 module is part of Python's standard library, so there's normally nothing extra to install.

Run the program:

Output

('name', 'billy')
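One caveat if you adapt this: the INSERT statement above is built by string concatenation, so a value containing a quote (such as O'Brien) produces invalid SQL, and untrusted input opens the door to SQL injection. Python's sqlite3 supports ? placeholders instead. Here is a plain Python sketch of the same backend using them, with the special ":memory:" path so no file is written:

```python
import sqlite3


class SQLite:
    def __init__(self, path: str) -> None:
        self.con = sqlite3.connect(path)
        self.cur = self.con.cursor()
        self.cur.execute("CREATE TABLE IF NOT EXISTS test (key TEXT, value TEXT)")

    def insert(self, key: str, value: str) -> None:
        # "?" placeholders let sqlite3 bind the values safely.
        self.cur.execute("INSERT INTO test VALUES (?, ?)", (key, value))
        self.con.commit()

    def print_all(self) -> list[tuple[str, str]]:
        rows = self.cur.execute("SELECT * FROM test").fetchall()
        for row in rows:
            print(row)
        return rows


db = SQLite(":memory:")
# The apostrophe would break a concatenated query string.
db.insert("name", "O'Brien")
db.print_all()  # prints ('name', "O'Brien")
```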

And now we can implement the Mongo version:

Mojo

struct Mongo(Database):
    var client: PythonObject
    var db: PythonObject
    var collection: PythonObject

    fn __init__(inout self, host: StringLiteral, port: Int, database: StringLiteral, collection: StringLiteral) raises:
        let MongoClient = Python.import_module("pymongo").MongoClient
        self.client = MongoClient(host, port)
        self.db = self.client[database]
        self.collection = self.db[collection]

    fn insert(self, key: String, value: String) raises:
        var document = String('{"')
        document += key
        document += '": "'
        document += value
        document += '"}'
        let my_document = Python.evaluate(document._strref_dangerous())
        _ = self.collection.insert_one(my_document)
        document._strref_keepalive()

    fn print_all(self) raises:
        var docs = self.collection.find()
        for doc in docs:
            print(doc)

You'll need to pip install pymongo if you want to run this, and follow MongoDB's installation instructions to start a local server.

Then run it:

Mojo

let db2 = Mongo("localhost", 27017, "test", "test")
insert_and_print(db2, "name", "billy")
Output

{'_id': ObjectId('655734c648d3297fa1b91b0b'), 'name': 'billy'}

This is simplified to demonstrate the functionality, but you could create an entire library following these principles wrapping Python libraries, and then introduce optimized Mojo implementations where you need better performance without changing the API.

Conclusion

You may have seen abstract methods with ... as their body scattered around Python code bases, and not understood why they're there. They're pervasive in ML libraries, where authors want to provide correctness and nice tooling while still having multiple implementations for CUDA, CPU, and the many emerging hardware backends. Mojo gives you an extra benefit: you can write generic, reusable functionality across multiple types, while retaining full type safety and compiler optimizations by writing concrete implementations for each type.

More trait features, such as default implementations, are on the way, so make sure to check back on the docs.

We're excited to see what you build with traits, please share your projects on the Discord and GitHub!

Jack Clayton, AI Developer Advocate


Jack started his career optimizing autonomous truck software for leading mining companies, including BHP and Caterpillar. Most recently he was designing computer vision software, putting AI inference pipelines into production for IDVerse. He is enormously passionate about the developer community, having been a Rust, Go, Python and C++ developer for over a decade. Jack enjoys making complicated topics simple and fun to learn, and he’s dedicated to teaching the world about Mojo 🔥.