From 092cb0a04ef4db8230bf781fbdc2a01d74a0c6a3 Mon Sep 17 00:00:00 2001 From: Alexander Heistermann Date: Thu, 2 Jul 2026 18:04:17 -0500 Subject: [PATCH] Add Lua trace ray functions (#1624) * adds `Spring.TraceRayBetweenPositions(xA, yA, zA, xB, yB, zB, type)` * adds `Spring.TraceRayInDirection(x, y, z, dx, dy, dz, length, type)` * type is a string, "unit", "feature", or "both" * both return an array of `{distance, objID, objType}` sorted by increasing distance --- rts/Lua/LuaSyncedRead.cpp | 138 ++++++++++++++++++++++++++++++++++++++ rts/Lua/LuaSyncedRead.h | 5 +- 2 files changed, 140 insertions(+), 3 deletions(-) diff --git a/rts/Lua/LuaSyncedRead.cpp b/rts/Lua/LuaSyncedRead.cpp index 40b25741b52..3983a71f3ff 100644 --- a/rts/Lua/LuaSyncedRead.cpp +++ b/rts/Lua/LuaSyncedRead.cpp @@ -399,6 +399,8 @@ bool LuaSyncedRead::PushEntries(lua_State* L) REGISTER_LUA_CFUNC(TraceRayGroundInDirection); REGISTER_LUA_CFUNC(TraceRayGroundBetweenPositions); + REGISTER_LUA_CFUNC(TraceRayInDirection); + REGISTER_LUA_CFUNC(TraceRayBetweenPositions); REGISTER_LUA_CFUNC(GetRadarErrorParams); @@ -9011,6 +9013,142 @@ int LuaSyncedRead::GetUnitScriptNames(lua_State* L) return 1; } + +static int TraceRayImpl(lua_State *const L, const float3 &pos, const float3 &dir, const float maxLen, std::string_view type) +{ + if (type != "unit" && type != "feature" && type != "both") + return luaL_error(L, "invalid type '%s', expected 'unit', 'feature', or 'both'", type.data()); + + const bool testUnits = (type == "unit" || type == "both"); + const bool testFeatures = (type == "feature" || type == "both"); + + QuadFieldQuery qfQuery; + quadField.GetQuadsOnRay(qfQuery, pos, dir, maxLen); + + spring::unordered_set testedUnitIDs; + spring::unordered_set testedFeatureIDs; + std::vector > hits; + + for (const int quadIdx : *qfQuery.quads) { + const CQuadField::Quad& quad = quadField.GetQuad(quadIdx); + + if (testUnits) { + for (const auto *unit : quad.units) { + if (!unit->HasCollidableStateBit(CSolidObject::CSTATE_BIT_QUADMAPRAYS)) + continue; + + if (!testedUnitIDs.insert(unit->id).second) + continue; + + if (!LuaUtils::IsUnitInLos(L, unit)) + continue; + + CollisionQuery cq; + if (CCollisionHandler::DetectHit(unit, unit->GetTransformMatrix(true), pos, pos + dir * maxLen, &cq, true)) { + const float len = cq.GetHitPosDist(pos, dir); + if (len > maxLen) // possibly a bug in CCollisionHandler::DetectHit? + continue; + hits.emplace_back(len, unit->id, "unit"); + } + } + } + + if (testFeatures) { + for (const auto *feature : quad.features) { + if (!feature->HasCollidableStateBit(CSolidObject::CSTATE_BIT_QUADMAPRAYS)) + continue; + + if (!testedFeatureIDs.insert(feature->id).second) + continue; + + if (!LuaUtils::IsFeatureVisible(L, feature)) + continue; + + CollisionQuery cq; + if (CCollisionHandler::DetectHit(feature, feature->GetTransformMatrix(true), pos, pos + dir * maxLen, &cq, true)) { + const float len = cq.GetHitPosDist(pos, dir); + if (len > maxLen) + continue; + hits.emplace_back(len, feature->id, "feature"); + } + } + } + } + + std::stable_sort(hits.begin(), hits.end(), [] (const auto& a, const auto& b) { + return std::get<0>(a) < std::get<0>(b); + }); + + lua_createtable(L, hits.size(), 0); + + int num = 0; + for (const auto& [hitLength, objectID, objectType] : hits) { + lua_createtable(L, 3, 0); + + lua_pushnumber(L, hitLength); + lua_rawseti(L, -2, 1); + lua_pushnumber(L, objectID); + lua_rawseti(L, -2, 2); + lua_pushstring(L, objectType); + lua_rawseti(L, -2, 3); + + lua_rawseti(L, -2, ++num); + } + + return 1; +} + +/*** Traces a ray from a position in a direction + * + * @function Spring.TraceRayInDirection + * + * Returns all unit and/or feature hits along a ray, sorted by distance + * from the start position. + * + * @param posX number + * @param posY number + * @param posZ number + * @param dirX number + * @param dirY number + * @param dirZ number + * @param maxLength number + * @param type string Object type to test: `"unit"`, `"feature"`, or `"both"` + * @return table[] hits Array of `{hitLength, objectID, objectType}` entries + */ +int LuaSyncedRead::TraceRayInDirection(lua_State* L) +{ + float3 pos(luaL_checkfloat(L, 1), luaL_checkfloat(L, 2), luaL_checkfloat(L, 3)); + float3 dir(luaL_checkfloat(L, 4), luaL_checkfloat(L, 5), luaL_checkfloat(L, 6)); + const float maxLen = luaL_optfloat(L, 7, 999999.f); + const char* type = luaL_checkstring(L, 8); + return TraceRayImpl(L, pos, dir, maxLen, type); +} + +/*** Traces a ray between two positions + * + * @function Spring.TraceRayBetweenPositions + * + * Checks for unit and/or feature collisions between two positions + * and returns all hits sorted by distance from the start position. + * + * @param startX number + * @param startY number + * @param startZ number + * @param endX number + * @param endY number + * @param endZ number + * @param type string Object type to test: `"unit"`, `"feature"`, or `"both"` + * @return table[] hits Array of `{hitLength, objectID, objectType}` entries + */ +int LuaSyncedRead::TraceRayBetweenPositions(lua_State* L) +{ + float3 start(luaL_checkfloat(L, 1), luaL_checkfloat(L, 2), luaL_checkfloat(L, 3)); + float3 end(luaL_checkfloat(L, 4), luaL_checkfloat(L, 5), luaL_checkfloat(L, 6)); + const char* type = luaL_checkstring(L, 7); + const auto [dir, length] = (end - start).GetNormalized(); + return TraceRayImpl(L, start, dir, length, type); +} + static int TraceRayGroundImpl(lua_State *const L, const float3 &pos, const float3 &dir, const float maxLen, const bool testWater) { const float rayLength = CGround::LineGroundWaterCol(pos, dir, maxLen, testWater, CLuaHandle::GetHandleSynced(L)); diff --git a/rts/Lua/LuaSyncedRead.h b/rts/Lua/LuaSyncedRead.h index 98c632ad9b0..dbec13c0ab7 100644 --- a/rts/Lua/LuaSyncedRead.h +++ b/rts/Lua/LuaSyncedRead.h @@ -302,9 +302,8 @@ class LuaSyncedRead { static int GetRadarErrorParams(lua_State* L); - static int TraceRay(lua_State* L); //TODO: not implemented - static int TraceRayUnits(lua_State* L); //TODO: not implemented - static int TraceRayFeatures(lua_State* L); //TODO: not implemented + static int TraceRayInDirection(lua_State* L); + static int TraceRayBetweenPositions(lua_State* L); static int TraceRayGroundBetweenPositions(lua_State* L); static int TraceRayGroundInDirection(lua_State* L); };