Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/funny-peas-compare.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ensembleui/react-runtime": patch
---

Fix: memoize conditional branch widgets
21 changes: 14 additions & 7 deletions packages/runtime/src/widgets/Conditional.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Expression } from "@ensembleui/react-framework";
import type { EnsembleWidget, Expression } from "@ensembleui/react-framework";
import { unwrapWidget, useRegisterBindings } from "@ensembleui/react-framework";
import { cloneDeep, head, isEmpty, last } from "lodash-es";
import { useMemo } from "react";
import { useMemo, useState } from "react";
import { WidgetRegistry } from "../registry";
import { EnsembleRuntime } from "../runtime";
import type { EnsembleWidgetProps } from "../shared/types";
Expand All @@ -26,6 +26,7 @@ export const Conditional: React.FC<ConditionalProps> = ({
conditions,
...props
}) => {
const [matched, setMatched] = useState<{ [key: string]: unknown }>({});
const [isValid, errorMessage] = hasProperStructure(conditions);
if (!isValid) throw Error(errorMessage);

Expand Down Expand Up @@ -55,11 +56,17 @@ export const Conditional: React.FC<ConditionalProps> = ({
if (trueIndex === undefined || trueIndex < 0) {
return null;
}
const extractedWidget = extractWidget(conditions[trueIndex]);
return {
...extractedWidget,
key: conditionStatements[trueIndex]?.toString(),
};
const key = conditionStatements[trueIndex]?.toString();

const extractedWidget =
key && matched[key]
? (matched[key] as EnsembleWidget)
: extractWidget(conditions[trueIndex]);

if (key && !matched[key]) {
setMatched((prev) => ({ ...prev, [key]: extractedWidget }));
}
return { ...extractedWidget, key };
}, [conditionStatements, conditions, trueIndex]);

if (!widget) {
Expand Down
107 changes: 106 additions & 1 deletion packages/runtime/src/widgets/__tests__/Conditional.test.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { render, screen } from "@testing-library/react";
/* eslint import/first: 0 */
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const framework = jest.requireActual("@ensembleui/react-framework");
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-unsafe-member-access
const unwrapWidgetSpy = jest.fn().mockImplementation(framework.unwrapWidget);
import { fireEvent, render, screen } from "@testing-library/react";
import { BrowserRouter } from "react-router-dom";
import type { ConditionalProps, ConditionalElement } from "../Conditional";
import {
Conditional,
Expand All @@ -7,9 +13,20 @@ import {
extractCondition,
} from "../Conditional";
import "../index";
import { EnsembleScreen } from "../../runtime/screen";

jest.mock("react-markdown", jest.fn());

// eslint-disable-next-line @typescript-eslint/no-unsafe-return
jest.mock("@ensembleui/react-framework", () => ({
...framework,
unwrapWidget: unwrapWidgetSpy,
}));

afterEach(() => {
jest.clearAllMocks();
});

describe("Conditional Component", () => {
test('renders the widget when "if" condition is met', () => {
const conditionalProps: ConditionalProps = {
Expand Down Expand Up @@ -233,3 +250,91 @@ describe("extractCondition Function", () => {
expect(extractedCondition).toBe("1 === 1");
});
});

describe("conditional widget memoization", () => {
it("should memoize branch widgets and prevent unnecessary re-renders", () => {
render(
<EnsembleScreen
screen={{
name: "test_conditional",
id: "test_conditional",
body: {
name: "Column",
properties: {
children: [
{
name: "Conditional",
properties: {
conditions: [
{
if: `\${ensemble.storage.get('number') < 0}`,
Text: {
text: "Less than 0",
},
},
{
if: `\${ensemble.storage.get('number') === 0}`,
Text: {
text: "Equals to 0",
},
},
{
if: `\${ensemble.storage.get('number') > 0}`,
Text: {
text: "Greater than 0",
},
},
],
},
},
{
name: "Button",
properties: {
label: "Increase",
onTap: {
executeCode:
"ensemble.storage.set('number', ensemble.storage.get('number') + 1)",
},
},
},
{
name: "Button",
properties: {
label: "Decrease",
onTap: {
executeCode:
"ensemble.storage.set('number', ensemble.storage.get('number') - 1)",
},
},
},
],
},
},
onLoad: { executeCode: 'ensemble.storage.set("number", -1)' },
}}
/>,
{
wrapper: BrowserRouter,
},
);

expect(unwrapWidgetSpy).toHaveBeenCalledTimes(1);
expect(screen.getByText("Less than 0")).not.toBeNull();

fireEvent.click(screen.getByText("Increase"));
expect(unwrapWidgetSpy).toHaveBeenCalledTimes(2);
expect(screen.getByText("Equals to 0")).not.toBeNull();

fireEvent.click(screen.getByText("Increase"));
expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3);
expect(screen.getByText("Greater than 0")).not.toBeNull();

fireEvent.click(screen.getByText("Decrease"));
expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3);
expect(screen.getByText("Equals to 0")).not.toBeNull();

fireEvent.click(screen.getByText("Decrease"));
expect(unwrapWidgetSpy).toHaveBeenCalledTimes(3);
expect(screen.getByText("Less than 0")).not.toBeNull();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add assertion about how many times EnsembleRuntime.render was called?

});
});