diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index 3bc246e5ac..45756132cf 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -49,6 +49,7 @@ import { SecretsModule } from './secrets/secrets.module'; import { SecurityPenetrationTestsModule } from './security-penetration-tests/security-penetration-tests.module'; import { StripeModule } from './stripe/stripe.module'; import { AdminOrganizationsModule } from './admin-organizations/admin-organizations.module'; +import { FrameworkInstanceRequirementsModule } from './framework-instance-requirements/framework-instance-requirements.module'; @Module({ imports: [ @@ -110,6 +111,7 @@ import { AdminOrganizationsModule } from './admin-organizations/admin-organizati SecurityPenetrationTestsModule, StripeModule, AdminOrganizationsModule, + FrameworkInstanceRequirementsModule, ], controllers: [AppController], providers: [ diff --git a/apps/api/src/controls/controls.service.ts b/apps/api/src/controls/controls.service.ts index 7950c40416..79a6ad1fd2 100644 --- a/apps/api/src/controls/controls.service.ts +++ b/apps/api/src/controls/controls.service.ts @@ -20,6 +20,9 @@ const controlInclude = { requirement: { select: { name: true, identifier: true }, }, + frameworkInstanceRequirement: { + select: { name: true, identifier: true }, + }, }, }, } satisfies Prisma.ControlInclude; @@ -76,6 +79,7 @@ export class ControlsService { include: { framework: true }, }, requirement: true, + frameworkInstanceRequirement: true, }, }, }, @@ -117,41 +121,71 @@ export class ControlsService { } async getOptions(organizationId: string) { - const [policies, tasks, frameworkInstances] = await Promise.all([ - db.policy.findMany({ - where: { organizationId }, - select: { id: true, name: true }, - orderBy: { name: 'asc' }, - }), - db.task.findMany({ - where: { organizationId }, - select: { id: true, title: true }, - orderBy: { title: 'asc' }, - }), - db.frameworkInstance.findMany({ - where: { organizationId }, - include: { - framework: { - include: { - requirements: { - select: { id: true, name: true, identifier: true }, + const [policies, tasks, frameworkInstances, instanceRequirements] = + await Promise.all([ + db.policy.findMany({ + where: { organizationId }, + select: { id: true, name: true }, + orderBy: { name: 'asc' }, + }), + db.task.findMany({ + where: { organizationId }, + select: { id: true, title: true }, + orderBy: { title: 'asc' }, + }), + db.frameworkInstance.findMany({ + where: { organizationId }, + include: { + framework: { + include: { + requirements: { + select: { id: true, name: true, identifier: true }, + }, }, }, }, - }, - }), - ]); + }), + db.frameworkInstanceRequirement.findMany({ + where: { + frameworkInstance: { organizationId }, + }, + select: { + id: true, + name: true, + identifier: true, + frameworkInstanceId: true, + frameworkInstance: { + select: { + framework: { select: { name: true } }, + }, + }, + }, + orderBy: { name: 'asc' }, + }), + ]); - const requirements = frameworkInstances.flatMap((fi) => + const templateRequirements = frameworkInstances.flatMap((fi) => fi.framework.requirements.map((req) => ({ id: req.id, name: req.name, identifier: req.identifier, frameworkInstanceId: fi.id, frameworkName: fi.framework.name, + isInstanceRequirement: false, })), ); + const customRequirements = instanceRequirements.map((req) => ({ + id: req.id, + name: req.name, + identifier: req.identifier, + frameworkInstanceId: req.frameworkInstanceId, + frameworkName: req.frameworkInstance.framework.name, + isInstanceRequirement: true, + })); + + const requirements = [...templateRequirements, ...customRequirements]; + return { policies, tasks, requirements }; } @@ -184,8 +218,14 @@ export class ControlsService { db.requirementMap.create({ data: { controlId: control.id, - requirementId: mapping.requirementId, frameworkInstanceId: mapping.frameworkInstanceId, + ...(mapping.requirementId && { + requirementId: mapping.requirementId, + }), + ...(mapping.frameworkInstanceRequirementId && { + frameworkInstanceRequirementId: + mapping.frameworkInstanceRequirementId, + }), }, }), ), diff --git a/apps/api/src/controls/dto/create-control.dto.ts b/apps/api/src/controls/dto/create-control.dto.ts index b899ce6180..63d50ca546 100644 --- a/apps/api/src/controls/dto/create-control.dto.ts +++ b/apps/api/src/controls/dto/create-control.dto.ts @@ -9,9 +9,18 @@ import { import { Type } from 'class-transformer'; class RequirementMappingDto { - @ApiProperty({ description: 'Requirement ID' }) + @ApiProperty({ description: 'Template requirement ID', required: false }) + @IsOptional() + @IsString() + requirementId?: string; + + @ApiProperty({ + description: 'Instance requirement ID', + required: false, + }) + @IsOptional() @IsString() - requirementId: string; + frameworkInstanceRequirementId?: string; @ApiProperty({ description: 'Framework instance ID' }) @IsString() diff --git a/apps/api/src/framework-instance-requirements/dto/create-framework-instance-requirement.dto.ts b/apps/api/src/framework-instance-requirements/dto/create-framework-instance-requirement.dto.ts new file mode 100644 index 0000000000..2cafb97999 --- /dev/null +++ b/apps/api/src/framework-instance-requirements/dto/create-framework-instance-requirement.dto.ts @@ -0,0 +1,43 @@ +import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger'; +import { + IsArray, + IsString, + IsNotEmpty, + IsOptional, + MaxLength, +} from 'class-validator'; + +export class CreateFrameworkInstanceRequirementDto { + @ApiProperty({ example: 'frm_abc123' }) + @IsString() + @IsNotEmpty() + @MaxLength(255) + frameworkInstanceId: string; + + @ApiProperty({ example: 'Custom Access Control' }) + @IsString() + @IsNotEmpty() + @MaxLength(255) + name: string; + + @ApiPropertyOptional({ example: 'CUSTOM-1' }) + @IsString() + @IsOptional() + @MaxLength(255) + identifier?: string; + + @ApiProperty({ example: 'Custom requirement for access control policies' }) + @IsString() + @IsNotEmpty() + @MaxLength(5000) + description: string; + + @ApiPropertyOptional({ + description: 'Control IDs to link to this requirement', + type: [String], + }) + @IsOptional() + @IsArray() + @IsString({ each: true }) + controlIds?: string[]; +} diff --git a/apps/api/src/framework-instance-requirements/dto/update-framework-instance-requirement.dto.ts b/apps/api/src/framework-instance-requirements/dto/update-framework-instance-requirement.dto.ts new file mode 100644 index 0000000000..04e960416c --- /dev/null +++ b/apps/api/src/framework-instance-requirements/dto/update-framework-instance-requirement.dto.ts @@ -0,0 +1,22 @@ +import { ApiPropertyOptional } from '@nestjs/swagger'; +import { IsString, IsOptional, MaxLength } from 'class-validator'; + +export class UpdateFrameworkInstanceRequirementDto { + @ApiPropertyOptional() + @IsString() + @IsOptional() + @MaxLength(255) + name?: string; + + @ApiPropertyOptional() + @IsString() + @IsOptional() + @MaxLength(255) + identifier?: string; + + @ApiPropertyOptional() + @IsString() + @IsOptional() + @MaxLength(5000) + description?: string; +} diff --git a/apps/api/src/framework-instance-requirements/framework-instance-requirements.controller.ts b/apps/api/src/framework-instance-requirements/framework-instance-requirements.controller.ts new file mode 100644 index 0000000000..9e0ec9d162 --- /dev/null +++ b/apps/api/src/framework-instance-requirements/framework-instance-requirements.controller.ts @@ -0,0 +1,92 @@ +import { + Body, + Controller, + Delete, + Get, + Param, + Patch, + Post, + Query, + UseGuards, +} from '@nestjs/common'; +import { + ApiTags, + ApiBearerAuth, + ApiOperation, + ApiQuery, +} from '@nestjs/swagger'; +import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../auth/permission.guard'; +import { RequirePermission } from '../auth/require-permission.decorator'; +import { OrganizationId } from '../auth/auth-context.decorator'; +import { FrameworkInstanceRequirementsService } from './framework-instance-requirements.service'; +import { CreateFrameworkInstanceRequirementDto } from './dto/create-framework-instance-requirement.dto'; +import { UpdateFrameworkInstanceRequirementDto } from './dto/update-framework-instance-requirement.dto'; + +@ApiTags('Framework Instance Requirements') +@ApiBearerAuth() +@UseGuards(HybridAuthGuard, PermissionGuard) +@Controller({ path: 'framework-instance-requirements', version: '1' }) +export class FrameworkInstanceRequirementsController { + constructor( + private readonly service: FrameworkInstanceRequirementsService, + ) {} + + @Get() + @RequirePermission('framework', 'read') + @ApiOperation({ + summary: 'List custom requirements for a framework instance', + }) + @ApiQuery({ name: 'frameworkInstanceId', required: true, type: String }) + async findAll( + @OrganizationId() organizationId: string, + @Query('frameworkInstanceId') frameworkInstanceId: string, + ) { + const data = await this.service.findAll( + frameworkInstanceId, + organizationId, + ); + return { data, count: data.length }; + } + + @Get(':id') + @RequirePermission('framework', 'read') + @ApiOperation({ summary: 'Get a single framework instance requirement' }) + async findOne( + @OrganizationId() organizationId: string, + @Param('id') id: string, + ) { + return this.service.findOne(id, organizationId); + } + + @Post() + @RequirePermission('framework', 'create') + @ApiOperation({ summary: 'Create a framework instance requirement' }) + async create( + @OrganizationId() organizationId: string, + @Body() dto: CreateFrameworkInstanceRequirementDto, + ) { + return this.service.create(dto, organizationId); + } + + @Patch(':id') + @RequirePermission('framework', 'update') + @ApiOperation({ summary: 'Update a framework instance requirement' }) + async update( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: UpdateFrameworkInstanceRequirementDto, + ) { + return this.service.update(id, dto, organizationId); + } + + @Delete(':id') + @RequirePermission('framework', 'delete') + @ApiOperation({ summary: 'Delete a framework instance requirement' }) + async delete( + @OrganizationId() organizationId: string, + @Param('id') id: string, + ) { + return this.service.delete(id, organizationId); + } +} diff --git a/apps/api/src/framework-instance-requirements/framework-instance-requirements.module.ts b/apps/api/src/framework-instance-requirements/framework-instance-requirements.module.ts new file mode 100644 index 0000000000..6624971e3d --- /dev/null +++ b/apps/api/src/framework-instance-requirements/framework-instance-requirements.module.ts @@ -0,0 +1,12 @@ +import { Module } from '@nestjs/common'; +import { AuthModule } from '../auth/auth.module'; +import { FrameworkInstanceRequirementsController } from './framework-instance-requirements.controller'; +import { FrameworkInstanceRequirementsService } from './framework-instance-requirements.service'; + +@Module({ + imports: [AuthModule], + controllers: [FrameworkInstanceRequirementsController], + providers: [FrameworkInstanceRequirementsService], + exports: [FrameworkInstanceRequirementsService], +}) +export class FrameworkInstanceRequirementsModule {} diff --git a/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.spec.ts b/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.spec.ts new file mode 100644 index 0000000000..ed05909dab --- /dev/null +++ b/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.spec.ts @@ -0,0 +1,251 @@ +import { NotFoundException } from '@nestjs/common'; +import { FrameworkInstanceRequirementsService } from './framework-instance-requirements.service'; + +// Mock the db module +jest.mock('@trycompai/db', () => ({ + db: { + frameworkInstance: { + findUnique: jest.fn(), + }, + frameworkInstanceRequirement: { + findMany: jest.fn(), + findUnique: jest.fn(), + create: jest.fn(), + update: jest.fn(), + delete: jest.fn(), + }, + }, +})); + +import { db } from '@trycompai/db'; + +const mockedDb = db as jest.Mocked; + +describe('FrameworkInstanceRequirementsService', () => { + let service: FrameworkInstanceRequirementsService; + + const orgId = 'org_1'; + const frameworkInstanceId = 'frm_1'; + const requirementId = 'fir_1'; + + beforeEach(() => { + service = new FrameworkInstanceRequirementsService(); + jest.clearAllMocks(); + }); + + describe('findAll', () => { + it('should return requirements for a valid framework instance', async () => { + const mockInstance = { id: frameworkInstanceId, organizationId: orgId }; + const mockRequirements = [ + { id: 'fir_1', name: 'Custom Req', requirementMaps: [] }, + ]; + + (mockedDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue( + mockInstance, + ); + ( + mockedDb.frameworkInstanceRequirement.findMany as jest.Mock + ).mockResolvedValue(mockRequirements); + + const result = await service.findAll(frameworkInstanceId, orgId); + + expect(mockedDb.frameworkInstance.findUnique).toHaveBeenCalledWith({ + where: { id: frameworkInstanceId, organizationId: orgId }, + }); + expect(result).toEqual(mockRequirements); + }); + + it('should throw NotFoundException if framework instance not found', async () => { + (mockedDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue( + null, + ); + + await expect( + service.findAll(frameworkInstanceId, orgId), + ).rejects.toThrow(NotFoundException); + }); + }); + + describe('findOne', () => { + it('should return a requirement if it belongs to the org', async () => { + const mockRequirement = { + id: requirementId, + name: 'Custom Req', + frameworkInstance: { organizationId: orgId }, + requirementMaps: [], + }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(mockRequirement); + + const result = await service.findOne(requirementId, orgId); + expect(result).toEqual(mockRequirement); + }); + + it('should throw NotFoundException if requirement not found', async () => { + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(null); + + await expect(service.findOne(requirementId, orgId)).rejects.toThrow( + NotFoundException, + ); + }); + + it('should throw NotFoundException if requirement belongs to different org', async () => { + const mockRequirement = { + id: requirementId, + frameworkInstance: { organizationId: 'org_other' }, + }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(mockRequirement); + + await expect(service.findOne(requirementId, orgId)).rejects.toThrow( + NotFoundException, + ); + }); + }); + + describe('create', () => { + it('should create a requirement for a valid framework instance', async () => { + const dto = { + frameworkInstanceId, + name: 'New Requirement', + description: 'A custom requirement', + }; + const mockInstance = { id: frameworkInstanceId, organizationId: orgId }; + const mockCreated = { id: 'fir_new', ...dto, identifier: '' }; + + (mockedDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue( + mockInstance, + ); + ( + mockedDb.frameworkInstanceRequirement.create as jest.Mock + ).mockResolvedValue(mockCreated); + + const result = await service.create(dto, orgId); + + expect(result).toEqual(mockCreated); + expect( + mockedDb.frameworkInstanceRequirement.create, + ).toHaveBeenCalledWith({ + data: { + frameworkInstanceId, + name: dto.name, + identifier: '', + description: dto.description, + }, + }); + }); + + it('should throw NotFoundException if framework instance not found', async () => { + (mockedDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue( + null, + ); + + await expect( + service.create( + { + frameworkInstanceId, + name: 'Test', + description: 'Test', + }, + orgId, + ), + ).rejects.toThrow(NotFoundException); + }); + }); + + describe('update', () => { + it('should update a requirement belonging to the org', async () => { + const existing = { + id: requirementId, + frameworkInstance: { organizationId: orgId }, + }; + const updated = { id: requirementId, name: 'Updated Name' }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(existing); + ( + mockedDb.frameworkInstanceRequirement.update as jest.Mock + ).mockResolvedValue(updated); + + const result = await service.update( + requirementId, + { name: 'Updated Name' }, + orgId, + ); + + expect(result).toEqual(updated); + }); + + it('should throw NotFoundException if requirement belongs to different org', async () => { + const existing = { + id: requirementId, + frameworkInstance: { organizationId: 'org_other' }, + }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(existing); + + await expect( + service.update(requirementId, { name: 'Updated' }, orgId), + ).rejects.toThrow(NotFoundException); + }); + }); + + describe('delete', () => { + it('should delete a requirement belonging to the org', async () => { + const existing = { + id: requirementId, + frameworkInstance: { organizationId: orgId }, + }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(existing); + ( + mockedDb.frameworkInstanceRequirement.delete as jest.Mock + ).mockResolvedValue(existing); + + const result = await service.delete(requirementId, orgId); + + expect(result).toEqual({ + message: 'Framework instance requirement deleted successfully', + }); + expect( + mockedDb.frameworkInstanceRequirement.delete, + ).toHaveBeenCalledWith({ where: { id: requirementId } }); + }); + + it('should throw NotFoundException if requirement not found', async () => { + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(null); + + await expect(service.delete(requirementId, orgId)).rejects.toThrow( + NotFoundException, + ); + }); + + it('should throw NotFoundException if requirement belongs to different org', async () => { + const existing = { + id: requirementId, + frameworkInstance: { organizationId: 'org_other' }, + }; + + ( + mockedDb.frameworkInstanceRequirement.findUnique as jest.Mock + ).mockResolvedValue(existing); + + await expect(service.delete(requirementId, orgId)).rejects.toThrow( + NotFoundException, + ); + }); + }); +}); diff --git a/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.ts b/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.ts new file mode 100644 index 0000000000..26aac1c2d7 --- /dev/null +++ b/apps/api/src/framework-instance-requirements/framework-instance-requirements.service.ts @@ -0,0 +1,172 @@ +import { Injectable, NotFoundException, Logger } from '@nestjs/common'; +import { db } from '@trycompai/db'; +import { CreateFrameworkInstanceRequirementDto } from './dto/create-framework-instance-requirement.dto'; +import { UpdateFrameworkInstanceRequirementDto } from './dto/update-framework-instance-requirement.dto'; + +@Injectable() +export class FrameworkInstanceRequirementsService { + private readonly logger = new Logger( + FrameworkInstanceRequirementsService.name, + ); + + async findAll(frameworkInstanceId: string, organizationId: string) { + const frameworkInstance = await db.frameworkInstance.findUnique({ + where: { id: frameworkInstanceId, organizationId }, + }); + + if (!frameworkInstance) { + throw new NotFoundException( + `Framework instance ${frameworkInstanceId} not found`, + ); + } + + return db.frameworkInstanceRequirement.findMany({ + where: { frameworkInstanceId }, + orderBy: { name: 'asc' }, + include: { + requirementMaps: { + include: { + control: { + include: { + tasks: true, + policies: true, + }, + }, + }, + }, + }, + }); + } + + async findOne(id: string, organizationId: string) { + const requirement = await db.frameworkInstanceRequirement.findUnique({ + where: { id }, + include: { + frameworkInstance: true, + requirementMaps: { + include: { + control: { + include: { + tasks: true, + policies: true, + }, + }, + }, + }, + }, + }); + + if (!requirement) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + if (requirement.frameworkInstance.organizationId !== organizationId) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + return requirement; + } + + async create( + dto: CreateFrameworkInstanceRequirementDto, + organizationId: string, + ) { + const frameworkInstance = await db.frameworkInstance.findUnique({ + where: { id: dto.frameworkInstanceId, organizationId }, + }); + + if (!frameworkInstance) { + throw new NotFoundException( + `Framework instance ${dto.frameworkInstanceId} not found`, + ); + } + + const requirement = await db.frameworkInstanceRequirement.create({ + data: { + frameworkInstanceId: dto.frameworkInstanceId, + name: dto.name, + identifier: dto.identifier ?? '', + description: dto.description, + }, + }); + + if (dto.controlIds && dto.controlIds.length > 0) { + await Promise.all( + dto.controlIds.map((controlId) => + db.requirementMap.create({ + data: { + controlId, + frameworkInstanceRequirementId: requirement.id, + frameworkInstanceId: dto.frameworkInstanceId, + }, + }), + ), + ); + } + + this.logger.log( + `Created framework instance requirement: ${requirement.name} (${requirement.id})`, + ); + return requirement; + } + + async update( + id: string, + dto: UpdateFrameworkInstanceRequirementDto, + organizationId: string, + ) { + const existing = await db.frameworkInstanceRequirement.findUnique({ + where: { id }, + include: { frameworkInstance: true }, + }); + + if (!existing) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + if (existing.frameworkInstance.organizationId !== organizationId) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + const updated = await db.frameworkInstanceRequirement.update({ + where: { id }, + data: dto, + }); + + this.logger.log( + `Updated framework instance requirement: ${updated.name} (${id})`, + ); + return updated; + } + + async delete(id: string, organizationId: string) { + const existing = await db.frameworkInstanceRequirement.findUnique({ + where: { id }, + include: { frameworkInstance: true }, + }); + + if (!existing) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + if (existing.frameworkInstance.organizationId !== organizationId) { + throw new NotFoundException( + `Framework instance requirement ${id} not found`, + ); + } + + await db.frameworkInstanceRequirement.delete({ where: { id } }); + this.logger.log(`Deleted framework instance requirement ${id}`); + return { message: 'Framework instance requirement deleted successfully' }; + } +} diff --git a/apps/api/src/frameworks/frameworks.service.ts b/apps/api/src/frameworks/frameworks.service.ts index 4b69e33ec9..08d28eecc9 100644 --- a/apps/api/src/frameworks/frameworks.service.ts +++ b/apps/api/src/frameworks/frameworks.service.ts @@ -120,26 +120,47 @@ export class FrameworksService { const { requirementsMapped: _, ...rest } = fi; // Fetch additional data - const [requirementDefinitions, tasks, requirementMaps] = - await Promise.all([ - db.frameworkEditorRequirement.findMany({ - where: { frameworkId: fi.frameworkId }, - orderBy: { name: 'asc' }, - }), - db.task.findMany({ - where: { organizationId, controls: { some: { organizationId } } }, - include: { controls: true }, - }), - db.requirementMap.findMany({ - where: { frameworkInstanceId }, - include: { control: true }, - }), - ]); + const [ + requirementDefinitions, + frameworkInstanceRequirements, + tasks, + requirementMaps, + ] = await Promise.all([ + db.frameworkEditorRequirement.findMany({ + where: { frameworkId: fi.frameworkId }, + orderBy: { name: 'asc' }, + }), + db.frameworkInstanceRequirement.findMany({ + where: { frameworkInstanceId }, + include: { + requirementMaps: { + include: { + control: { + include: { + tasks: true, + policies: true, + }, + }, + }, + }, + }, + orderBy: { createdAt: 'asc' }, + }), + db.task.findMany({ + where: { organizationId, controls: { some: { organizationId } } }, + include: { controls: true }, + }), + db.requirementMap.findMany({ + where: { frameworkInstanceId }, + include: { control: true }, + }), + ]); return { ...rest, controls: Array.from(controlsMap.values()), requirementDefinitions, + frameworkInstanceRequirements, tasks, requirementMaps, }; @@ -206,36 +227,55 @@ export class FrameworksService { throw new NotFoundException('Framework instance not found'); } - const [allReqDefs, relatedControls, tasks] = await Promise.all([ - db.frameworkEditorRequirement.findMany({ - where: { frameworkId: fi.frameworkId }, - }), - db.requirementMap.findMany({ - where: { frameworkInstanceId, requirementId: requirementKey }, - include: { - control: { - include: { - policies: { - select: { id: true, name: true, status: true }, + const [allReqDefs, allInstanceReqs, relatedControls, tasks] = + await Promise.all([ + db.frameworkEditorRequirement.findMany({ + where: { frameworkId: fi.frameworkId }, + }), + db.frameworkInstanceRequirement.findMany({ + where: { frameworkInstanceId }, + }), + db.requirementMap.findMany({ + where: { + frameworkInstanceId, + OR: [ + { requirementId: requirementKey }, + { frameworkInstanceRequirementId: requirementKey }, + ], + }, + include: { + control: { + include: { + policies: { + select: { id: true, name: true, status: true }, + }, }, }, }, - }, - }), - db.task.findMany({ - where: { organizationId }, - include: { controls: true }, - }), - ]); + }), + db.task.findMany({ + where: { organizationId }, + include: { controls: true }, + }), + ]); + + // Look up in both template and instance requirements + const requirement = + allReqDefs.find((r) => r.id === requirementKey) ?? + allInstanceReqs.find((r) => r.id === requirementKey); - const requirement = allReqDefs.find((r) => r.id === requirementKey); if (!requirement) { throw new NotFoundException('Requirement not found'); } - const siblingRequirements = allReqDefs - .filter((r) => r.id !== requirementKey) - .map((r) => ({ id: r.id, name: r.name })); + // Siblings include both template and instance requirements + const allRequirements = [ + ...allReqDefs.map((r) => ({ id: r.id, name: r.name })), + ...allInstanceReqs.map((r) => ({ id: r.id, name: r.name })), + ]; + const siblingRequirements = allRequirements.filter( + (r) => r.id !== requirementKey, + ); return { requirement, diff --git a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx index f0bee7cbdd..28d793d5cf 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx @@ -4,6 +4,7 @@ import type { FrameworkEditorFramework, FrameworkEditorRequirement, FrameworkInstance, + FrameworkInstanceRequirement, RequirementMap, } from '@db'; import { InputGroup, InputGroupAddon, InputGroupInput } from '@trycompai/design-system'; @@ -25,34 +26,57 @@ interface RequirementsTableProps { frameworkInstance: FrameworkInstance & { framework: FrameworkEditorFramework; }; - requirement: FrameworkEditorRequirement; + requirement: FrameworkEditorRequirement | null; + frameworkInstanceRequirement?: FrameworkInstanceRequirement | null; })[]; orgId: string; } +function getRequirementData(req: RequirementsTableProps['requirements'][number]) { + if (req.requirement) { + return { + id: req.requirement.id, + name: req.requirement.name, + description: req.requirement.description, + identifier: req.requirement.identifier, + }; + } + if (req.frameworkInstanceRequirement) { + return { + id: req.frameworkInstanceRequirement.id, + name: req.frameworkInstanceRequirement.name, + description: req.frameworkInstanceRequirement.description, + identifier: req.frameworkInstanceRequirement.identifier, + }; + } + return null; +} + export function RequirementsTable({ requirements, orgId }: RequirementsTableProps) { const router = useRouter(); const [searchTerm, setSearchTerm] = useState(''); - // Filter requirements data based on search term const filteredRequirements = useMemo(() => { if (!searchTerm.trim()) return requirements; const searchLower = searchTerm.toLowerCase(); return requirements.filter((req) => { - // Search in ID, name, and description from the nested requirement object + const data = getRequirementData(req); + if (!data) return false; return ( - (req.requirement.id?.toLowerCase() || '').includes(searchLower) || - (req.requirement.name?.toLowerCase() || '').includes(searchLower) || - (req.requirement.description?.toLowerCase() || '').includes(searchLower) || - (req.requirement.identifier?.toLowerCase() || '').includes(searchLower) // Also search identifier + (data.id?.toLowerCase() || '').includes(searchLower) || + (data.name?.toLowerCase() || '').includes(searchLower) || + (data.description?.toLowerCase() || '').includes(searchLower) || + (data.identifier?.toLowerCase() || '').includes(searchLower) ); }); }, [requirements, searchTerm]); - const handleRowClick = (requirement: RequirementMap) => { + const handleRowClick = (req: RequirementsTableProps['requirements'][number]) => { + const data = getRequirementData(req); + if (!data) return; router.push( - `/${orgId}/frameworks/${requirement.frameworkInstanceId}/requirements/${requirement.requirementId}`, + `/${orgId}/frameworks/${req.frameworkInstanceId}/requirements/${data.id}`, ); }; @@ -88,31 +112,35 @@ export function RequirementsTable({ requirements, orgId }: RequirementsTableProp ) : ( - filteredRequirements.map((requirement) => ( - handleRowClick(requirement)} - onKeyDown={(event) => { - if (event.key === 'Enter' || event.key === ' ') { - event.preventDefault(); - handleRowClick(requirement); - } - }} - > - - - {requirement.requirement.name} - - - - - {requirement.requirement.description} - - - - )) + filteredRequirements.map((requirement) => { + const data = getRequirementData(requirement); + if (!data) return null; + return ( + handleRowClick(requirement)} + onKeyDown={(event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault(); + handleRowClick(requirement); + } + }} + > + + + {data.name} + + + + + {data.description} + + + + ); + }) )} diff --git a/apps/app/src/app/(app)/[orgId]/controls/components/CreateControlSheet.tsx b/apps/app/src/app/(app)/[orgId]/controls/components/CreateControlSheet.tsx index afdf4858e2..f8764e8e41 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/components/CreateControlSheet.tsx +++ b/apps/app/src/app/(app)/[orgId]/controls/components/CreateControlSheet.tsx @@ -29,7 +29,8 @@ const createControlSchema = z.object({ requirementMappings: z .array( z.object({ - requirementId: z.string(), + requirementId: z.string().optional(), + frameworkInstanceRequirementId: z.string().optional(), frameworkInstanceId: z.string(), }), ) @@ -49,6 +50,7 @@ export function CreateControlSheet({ identifier: string; frameworkInstanceId: string; frameworkName: string; + isInstanceRequirement?: boolean; }[]; }) { const { createControl } = useControls(); @@ -116,6 +118,7 @@ export function CreateControlSheet({ value: req.id, label: `${req.frameworkName}: ${req.identifier} - ${req.name}`, frameworkInstanceId: req.frameworkInstanceId, + isInstanceRequirement: req.isInstanceRequirement ?? false, })), [requirements], ); @@ -158,9 +161,14 @@ export function CreateControlSheet({ ); const handleRequirementsChange = useCallback( - (options: (Option & { frameworkInstanceId?: string })[], onChange: (value: any) => void) => { + ( + options: (Option & { frameworkInstanceId?: string; isInstanceRequirement?: boolean })[], + onChange: (value: any) => void, + ) => { const mappings = options.map((option) => ({ - requirementId: option.value, + ...(option.isInstanceRequirement + ? { frameworkInstanceRequirementId: option.value } + : { requirementId: option.value }), frameworkInstanceId: option.frameworkInstanceId || '', })); onChange(mappings); @@ -290,20 +298,27 @@ export function CreateControlSheet({ control={form.control} name="requirementMappings" render={({ field }) => { - const selectedOptions: (Option & { frameworkInstanceId?: string })[] = ( - field.value || [] - ) + const selectedOptions: (Option & { + frameworkInstanceId?: string; + isInstanceRequirement?: boolean; + })[] = (field.value || []) .map((mapping) => { - const req = requirements.find((r) => r.id === mapping.requirementId); + const reqId = + mapping.requirementId ?? mapping.frameworkInstanceRequirementId; + const req = requirements.find((r) => r.id === reqId); return req ? { value: req.id, label: `${req.frameworkName}: ${req.identifier} - ${req.name}`, frameworkInstanceId: req.frameworkInstanceId, + isInstanceRequirement: req.isInstanceRequirement ?? false, } : null; }) - .filter(Boolean) as (Option & { frameworkInstanceId?: string })[]; + .filter(Boolean) as (Option & { + frameworkInstanceId?: string; + isInstanceRequirement?: boolean; + })[]; return ( @@ -314,12 +329,18 @@ export function CreateControlSheet({ value={selectedOptions} onChange={(options) => handleRequirementsChange( - options as (Option & { frameworkInstanceId?: string })[], + options as (Option & { + frameworkInstanceId?: string; + isInstanceRequirement?: boolean; + })[], field.onChange, ) } defaultOptions={ - requirementOptions as (Option & { frameworkInstanceId?: string })[] + requirementOptions as (Option & { + frameworkInstanceId?: string; + isInstanceRequirement?: boolean; + })[] } placeholder="Search and select requirements..." emptyIndicator={ diff --git a/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts b/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts index b46df5eedd..a4961548d9 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts +++ b/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts @@ -15,7 +15,8 @@ interface CreateControlPayload { policyIds?: string[]; taskIds?: string[]; requirementMappings?: { - requirementId: string; + requirementId?: string; + frameworkInstanceRequirementId?: string; frameworkInstanceId: string; }[]; } diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/CreateRequirementSheet.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/CreateRequirementSheet.tsx new file mode 100644 index 0000000000..d9a7dd3c83 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/CreateRequirementSheet.tsx @@ -0,0 +1,230 @@ +'use client'; + +import { apiClient } from '@/lib/api-client'; +import { useMediaQuery } from '@trycompai/ui/hooks'; +import { Button } from '@trycompai/ui/button'; +import { + Combobox, + ComboboxChip, + ComboboxChips, + ComboboxChipsInput, + ComboboxContent, + ComboboxItem, + ComboboxList, + Drawer, + DrawerContent, + DrawerHeader, + DrawerTitle, + Field, + FieldError, + FieldGroup, + FieldLabel, + HStack, + Input, + Sheet, + SheetBody, + SheetContent, + SheetFooter, + SheetHeader, + SheetTitle, + Textarea, + useComboboxAnchor, +} from '@trycompai/design-system'; +import { ArrowRight } from '@trycompai/design-system/icons'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { useMemo, useState } from 'react'; +import { Controller, useForm } from 'react-hook-form'; +import { toast } from 'sonner'; +import { z } from 'zod'; + +const createRequirementSchema = z.object({ + name: z.string().min(1, { message: 'Name is required' }), + identifier: z.string().optional(), + description: z.string().min(1, { message: 'Description is required' }), + controlIds: z.array(z.string()).optional(), +}); + +interface CreateRequirementSheetProps { + open: boolean; + onOpenChange: (open: boolean) => void; + frameworkInstanceId: string; + onCreated: () => void; + controls: { id: string; name: string }[]; +} + +export function CreateRequirementSheet({ + open, + onOpenChange, + frameworkInstanceId, + onCreated, + controls, +}: CreateRequirementSheetProps) { + const isDesktop = useMediaQuery('(min-width: 768px)'); + const [isSubmitting, setIsSubmitting] = useState(false); + + const { + register, + handleSubmit, + reset, + control, + formState: { errors }, + } = useForm>({ + resolver: zodResolver(createRequirementSchema), + defaultValues: { + name: '', + identifier: '', + description: '', + controlIds: [], + }, + }); + + const onSubmit = async (data: z.infer) => { + setIsSubmitting(true); + try { + await apiClient.post('/v1/framework-instance-requirements', { + ...data, + frameworkInstanceId, + }); + toast.success('Requirement created'); + onOpenChange(false); + reset(); + onCreated(); + } catch { + toast.error('Failed to create requirement'); + } finally { + setIsSubmitting(false); + } + }; + + const requirementForm = ( +
+ + + Requirement Name + + + + + + Identifier (Optional) + + + + + + Description +