Rendering different content based on a discriminator field

When your app handles heterogeneous lists of structured content, where each item must be rendered differently depending on its type, consider using rx.match with an explicit “discriminator” field on each object.

In the following example, I define a BaseMessage class that automatically sets its type field to the (sub)class name when the message is serialized. Then I define several subclasses of BaseMessage, each with its own fields, type annotations, and render method.

The render method converts a message into Reflex components. For subclasses, it should work both statically (in Python, on a concrete instance) and dynamically (on a Var inside rx.foreach). The exception is BaseMessage.render, a dispatch function that returns an rx.match over message.type. Because of this, it must be called as an unbound method, passing a Var (from rx.foreach) as the self argument.

Let's see how it looks:

import reflex as rx


class BaseMessage(rx.Base):
    """Base class for heterogeneous messages that are rendered by type.

    The ``type`` field is a discriminator holding the name of the concrete
    subclass. It is set when an instance is constructed and forced again in
    ``dict()``, so both the Python attribute and the serialized form carry
    the value that ``BaseMessage.render`` dispatches on.

    Subclasses add their own fields and a ``render`` method that should work
    on a concrete instance (statically, in Python) and on a Var (inside
    ``rx.foreach``).
    """

    # Discriminator field; always overwritten with the concrete class name.
    type: str = "BaseMessage"

    def __init__(self, **data):
        # Without this, the class-level default would leave e.g.
        # ``TextMessage(text="x").type == "BaseMessage"`` on the Python side,
        # even though dict() reports "TextMessage".
        data["type"] = self.__class__.__name__
        super().__init__(**data)

    def dict(self, *args, **kwargs):
        """Return the model as a dict whose ``type`` is the concrete class name."""
        return super().dict(*args, **kwargs) | {"type": self.__class__.__name__}

    @classmethod
    def _renderable_subclasses(cls) -> list:
        """Return every transitive subclass of ``cls`` that overrides ``render``.

        ``__subclasses__()`` only lists direct children, so the whole
        hierarchy is walked here. Classes that inherit ``BaseMessage.render``
        unchanged are skipped: dispatching to them would call this dispatcher
        again with a Var of the same type and recurse forever.
        """
        found = []
        seen = set()
        pending = list(cls.__subclasses__())
        while pending:
            sub = pending.pop()
            if sub in seen:
                continue  # reachable more than once via multiple inheritance
            seen.add(sub)
            pending.extend(sub.__subclasses__())
            if sub.render is not BaseMessage.render:
                found.append(sub)
        return found

    def render(self) -> rx.Component | rx.Var:
        """Dispatch rendering of a message Var to the matching subclass.

        Must be called unbound with a Var typed as ``BaseMessage`` (or a
        subclass), e.g. ``BaseMessage.render(m)`` inside ``rx.foreach``.

        Returns:
            An ``rx.match`` over ``self.type`` with one case per subclass
            that defines its own ``render``.

        Raises:
            TypeError: if ``self`` is not a Var whose type is a BaseMessage
                class.
        """
        # Only deal with BaseMessage-typed vars. Check that _var_type is a
        # class first, so issubclass() cannot itself raise on non-class types.
        if not isinstance(self, rx.Var) or not (
            isinstance(self._var_type, type)
            and issubclass(self._var_type, BaseMessage)
        ):
            raise TypeError(f"Cannot render {repr(self)}")

        # One match case per renderable subclass. The Var is converted with
        # .to(cls) so that subclass's render() can reach its own fields.
        return rx.match(
            self.type,
            *(
                (cls.__name__, cls.render(self.to(cls)))
                for cls in BaseMessage._renderable_subclasses()
            ),
        )


class TextMessage(BaseMessage):
    """A message containing plain, unformatted text."""

    # The text shown by render().
    text: str

    def render(self) -> rx.Component:
        """Display ``text`` in an ``rx.text`` component."""
        content = self.text
        return rx.text(content)


class TableMessage(BaseMessage):
    """A message rendered as a table.

    Each entry of ``rows`` is one table row, mapping a column name to its cell
    text. The header row is built from the keys of the first row.
    """

    # NOTE(review): body cells are emitted in each row's own key order, while
    # the header uses the first row's key order, so all rows are assumed to
    # share the same keys in the same order. An empty ``rows`` makes
    # ``self.rows[0]`` raise IndexError on a concrete instance (and presumably
    # fail on the frontend for a Var) — confirm whether empty tables matter.
    rows: list[dict[str, str]]

    def render(self) -> rx.Component:
        """Render ``rows`` as an ``rx.table`` with a header row of column names."""
        return rx.table.root(
            rx.table.header(
                rx.table.row(
                    # rx.foreach over a dict passes (key, value) pairs as kv;
                    # the header shows the keys. Column headers use
                    # column_header_cell; row_header_cell is meant for the
                    # leading cell of a body row.
                    rx.foreach(
                        self.rows[0],
                        lambda kv: rx.table.column_header_cell(kv[0]),
                    ),
                ),
            ),
            rx.table.body(
                rx.foreach(
                    self.rows,
                    lambda row: rx.table.row(
                        # Body cells show each row's values.
                        rx.foreach(
                            row,
                            lambda kv: rx.table.cell(kv[1]),
                        ),
                    ),
                ),
            ),
        )


class MarkdownMessage(BaseMessage):
    """A message whose body is Markdown source."""

    # Markdown source handed to rx.markdown by render().
    content: str

    def render(self) -> rx.Component:
        """Render ``content`` through ``rx.markdown``."""
        source = self.content
        return rx.markdown(source)


class State(rx.State):
    """Application state holding one example of each message kind."""

    # Demo messages; index() renders each one through BaseMessage.render.
    messages: list[BaseMessage] = [
        TextMessage(text="I'm plain text!"),
        TableMessage(
            rows=[
                {"a": "1", "b": "2"},
                {"a": "3", "b": "4"},
            ]
        ),
        MarkdownMessage(
            content=(
                "# This is markdown\n\n"
                "- you can have\n"
                "- bullet points\n\n"
                "```python\n"
                "print('and code blocks')\n"
                "```"
            )
        ),
    ]


def index() -> rx.Component:
    """Build the index page: a heading plus one card per message in State.messages."""

    def message_card(message):
        # `message` is a Var supplied by rx.foreach, so rendering goes through
        # the unbound dispatcher rather than a subclass method.
        return rx.card(BaseMessage.render(message), min_width="50vw")

    return rx.container(
        rx.color_mode.button(position="top-right"),
        rx.vstack(
            rx.heading("Welcome to Message Dispatch!", size="9"),
            rx.separator(),
            rx.foreach(State.messages, message_card),
        ),
        rx.logo(),
    )


# Create the Reflex application and register the `index` page function with it.
app = rx.App()
app.add_page(index)

3 Likes