Skip to content

Commit c18f88d

Browse files
committed
cleanup
1 parent 0c95dbc commit c18f88d

File tree

2 files changed

+122
-83
lines changed

2 files changed

+122
-83
lines changed

script/vm/compiler.lua

Lines changed: 110 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ local vm = require 'vm.vm'
1010
---@field _compiledNodes boolean
1111
---@field _node vm.node
1212
---@field _globalBase table
13+
---@field cindex integer
1314

1415
-- 该函数有副作用,会给source绑定node!
1516
local function bindDocs(source)
@@ -550,89 +551,20 @@ local function matchCall(source)
550551
end
551552
end
552553

553-
---@return vm.node?
554+
---@return vm.node
554555
local function getReturn(func, index, args)
555-
if func.special == 'setmetatable' then
556-
if not args then
557-
return nil
558-
end
559-
return getReturnOfSetMetaTable(args)
560-
end
561-
if func.special == 'pcall' and index > 1 then
562-
if not args then
563-
return nil
564-
end
565-
local newArgs = {}
566-
for i = 2, #args do
567-
newArgs[#newArgs+1] = args[i]
568-
end
569-
return getReturn(args[1], index - 1, newArgs)
570-
end
571-
if func.special == 'xpcall' and index > 1 then
572-
if not args then
573-
return nil
574-
end
575-
local newArgs = {}
576-
for i = 3, #args do
577-
newArgs[#newArgs+1] = args[i]
578-
end
579-
return getReturn(args[1], index - 1, newArgs)
580-
end
581-
if func.special == 'require' then
582-
if not args then
583-
return nil
584-
end
585-
local nameArg = args[1]
586-
if not nameArg or nameArg.type ~= 'string' then
587-
return nil
588-
end
589-
local name = nameArg[1]
590-
if not name or type(name) ~= 'string' then
591-
return nil
592-
end
593-
local uri = rpath.findUrisByRequirePath(guide.getUri(func), name)[1]
594-
if not uri then
595-
return nil
596-
end
597-
local state = files.getState(uri)
598-
local ast = state and state.ast
599-
if not ast then
600-
return nil
601-
end
602-
return vm.compileNode(ast)
556+
if not func._callReturns then
557+
func._callReturns = {}
603558
end
604-
local funcNode = vm.compileNode(func)
605-
---@type vm.node?
606-
local result
607-
for mfunc in funcNode:eachObject() do
608-
if mfunc.type == 'function'
609-
or mfunc.type == 'doc.type.function' then
610-
---@cast mfunc parser.object
611-
local returnObject = vm.getReturnOfFunction(mfunc, index)
612-
if returnObject then
613-
local returnNode = vm.compileNode(returnObject)
614-
for rnode in returnNode:eachObject() do
615-
if rnode.type == 'generic' then
616-
returnNode = rnode:resolve(guide.getUri(func), args)
617-
break
618-
end
619-
end
620-
if returnNode then
621-
for rnode in returnNode:eachObject() do
622-
-- TODO: narrow type
623-
if rnode.type ~= 'doc.generic.name' then
624-
result = result or vm.createNode()
625-
result:merge(rnode)
626-
end
627-
end
628-
if result and returnNode:isOptional() then
629-
result:addOptional()
630-
end
631-
end
632-
end
633-
end
559+
if not func._callReturns[index] then
560+
func._callReturns[index] = {
561+
type = 'call.return',
562+
parent = func,
563+
cindex = index,
564+
args = args,
565+
}
634566
end
635-
return result
567+
return vm.compileNode(func._callReturns[index])
636568
end
637569

638570
---@param source parser.object
@@ -765,13 +697,13 @@ function vm.selectNode(list, index)
765697
local result
766698
if exp.type == 'call' then
767699
result = getReturn(exp.node, index, exp.args)
768-
if not result then
769-
return vm.createNode(vm.declareGlobal('type', 'unknown')), exp
700+
if result:isEmpty() then
701+
result:merge(vm.declareGlobal('type', 'unknown'))
770702
end
771703
else
772704
---@type vm.node
773705
result = vm.compileNode(exp)
774-
if result and exp.type == 'varargs' and result:isEmpty() then
706+
if exp.type == 'varargs' and result:isEmpty() then
775707
result:merge(vm.declareGlobal('type', 'unknown'))
776708
end
777709
end
@@ -1596,6 +1528,101 @@ local compilerSwitch = util.switch()
15961528
vm.setNode(source, vm.declareGlobal('type', 'nil'))
15971529
end
15981530
end)
1531+
: case 'call.return'
1532+
---@param source parser.object
1533+
: call(function (source)
1534+
local func = source.parent
1535+
local args = source.args
1536+
local index = source.cindex
1537+
if func.special == 'setmetatable' then
1538+
if not args then
1539+
return
1540+
end
1541+
vm.setNode(source, getReturnOfSetMetaTable(args))
1542+
return
1543+
end
1544+
if func.special == 'pcall' and index > 1 then
1545+
if not args then
1546+
return
1547+
end
1548+
local newArgs = {}
1549+
for i = 2, #args do
1550+
newArgs[#newArgs+1] = args[i]
1551+
end
1552+
local node = getReturn(args[1], index - 1, newArgs)
1553+
if node then
1554+
vm.setNode(source, node)
1555+
end
1556+
return
1557+
end
1558+
if func.special == 'xpcall' and index > 1 then
1559+
if not args then
1560+
return
1561+
end
1562+
local newArgs = {}
1563+
for i = 3, #args do
1564+
newArgs[#newArgs+1] = args[i]
1565+
end
1566+
local node = getReturn(args[1], index - 1, newArgs)
1567+
if node then
1568+
vm.setNode(source, node)
1569+
end
1570+
return
1571+
end
1572+
if func.special == 'require' then
1573+
if not args then
1574+
return
1575+
end
1576+
local nameArg = args[1]
1577+
if not nameArg or nameArg.type ~= 'string' then
1578+
return
1579+
end
1580+
local name = nameArg[1]
1581+
if not name or type(name) ~= 'string' then
1582+
return
1583+
end
1584+
local uri = rpath.findUrisByRequirePath(guide.getUri(func), name)[1]
1585+
if not uri then
1586+
return
1587+
end
1588+
local state = files.getState(uri)
1589+
local ast = state and state.ast
1590+
if not ast then
1591+
return
1592+
end
1593+
vm.setNode(source, vm.compileNode(ast))
1594+
return
1595+
end
1596+
local funcNode = vm.compileNode(func)
1597+
---@type vm.node?
1598+
for mfunc in funcNode:eachObject() do
1599+
if mfunc.type == 'function'
1600+
or mfunc.type == 'doc.type.function' then
1601+
---@cast mfunc parser.object
1602+
local returnObject = vm.getReturnOfFunction(mfunc, index)
1603+
if returnObject then
1604+
local returnNode = vm.compileNode(returnObject)
1605+
for rnode in returnNode:eachObject() do
1606+
if rnode.type == 'generic' then
1607+
returnNode = rnode:resolve(guide.getUri(func), args)
1608+
break
1609+
end
1610+
end
1611+
if returnNode then
1612+
for rnode in returnNode:eachObject() do
1613+
-- TODO: narrow type
1614+
if rnode.type ~= 'doc.generic.name' then
1615+
vm.setNode(source, rnode)
1616+
end
1617+
end
1618+
if returnNode:isOptional() then
1619+
vm.getNode(source):addOptional()
1620+
end
1621+
end
1622+
end
1623+
end
1624+
end
1625+
end)
15991626
: case 'main'
16001627
: call(function (source)
16011628
if source.returns then

test/type_inference/init.lua

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3218,3 +3218,15 @@ if true then
32183218
end
32193219
local <?x?> = n or 0
32203220
]]
3221+
3222+
--TEST 'number' [=[
3223+
--local <?x?> = F()--[[@as number]]
3224+
--]=]
3225+
--
3226+
--TEST 'number' [=[
3227+
--local function f()
3228+
-- return F()--[[@as number]]
3229+
--end
3230+
--
3231+
--local <?x?> = f()
3232+
--]=]

0 commit comments

Comments
 (0)