Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions rts/Lua/LuaSyncedRead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ bool LuaSyncedRead::PushEntries(lua_State* L)

REGISTER_LUA_CFUNC(TraceRayGroundInDirection);
REGISTER_LUA_CFUNC(TraceRayGroundBetweenPositions);
REGISTER_LUA_CFUNC(TraceRayInDirection);
REGISTER_LUA_CFUNC(TraceRayBetweenPositions);

REGISTER_LUA_CFUNC(GetRadarErrorParams);

Expand Down Expand Up @@ -9011,6 +9013,142 @@ int LuaSyncedRead::GetUnitScriptNames(lua_State* L)
return 1;
}


static int TraceRayImpl(lua_State *const L, const float3 &pos, const float3 &dir, const float maxLen, std::string_view type)
{
if (type != "unit" && type != "feature" && type != "both")
return luaL_error(L, "invalid type '%s', expected 'unit', 'feature', or 'both'", type.data());

const bool testUnits = (type == "unit" || type == "both");
const bool testFeatures = (type == "feature" || type == "both");

QuadFieldQuery qfQuery;
quadField.GetQuadsOnRay(qfQuery, pos, dir, maxLen);

spring::unordered_set <int> testedUnitIDs;
spring::unordered_set <int> testedFeatureIDs;
std::vector <std::tuple<float, int, const char*>> hits;

for (const int quadIdx : *qfQuery.quads) {
const CQuadField::Quad& quad = quadField.GetQuad(quadIdx);

if (testUnits) {
for (const auto *unit : quad.units) {
if (!unit->HasCollidableStateBit(CSolidObject::CSTATE_BIT_QUADMAPRAYS))
continue;

if (!testedUnitIDs.insert(unit->id).second)
continue;

if (!LuaUtils::IsUnitInLos(L, unit))
continue;

CollisionQuery cq;
if (CCollisionHandler::DetectHit(unit, unit->GetTransformMatrix(true), pos, pos + dir * maxLen, &cq, true)) {
const float len = cq.GetHitPosDist(pos, dir);
if (len > maxLen) // possibly a bug in CCollisionHandler::DetectHit?
continue;
hits.emplace_back(len, unit->id, "unit");
}
}
}

if (testFeatures) {
for (const auto *feature : quad.features) {
if (!feature->HasCollidableStateBit(CSolidObject::CSTATE_BIT_QUADMAPRAYS))
continue;

if (!testedFeatureIDs.insert(feature->id).second)
continue;

if (!LuaUtils::IsFeatureVisible(L, feature))
continue;

CollisionQuery cq;
if (CCollisionHandler::DetectHit(feature, feature->GetTransformMatrix(true), pos, pos + dir * maxLen, &cq, true)) {
const float len = cq.GetHitPosDist(pos, dir);
if (len > maxLen)
continue;
hits.emplace_back(len, feature->id, "feature");
}
}
}
}

std::stable_sort(hits.begin(), hits.end(), [] (const auto& a, const auto& b) {
return std::get<0>(a) < std::get<0>(b);
});

lua_createtable(L, hits.size(), 0);

int num = 0;
for (const auto& [hitLength, objectID, objectType] : hits) {
lua_createtable(L, 3, 0);

lua_pushnumber(L, hitLength);
lua_rawseti(L, -2, 1);
lua_pushnumber(L, objectID);
lua_rawseti(L, -2, 2);
lua_pushstring(L, objectType);
lua_rawseti(L, -2, 3);

lua_rawseti(L, -2, ++num);
}

return 1;
}

/*** Traces a ray from a position in a direction
*
* @function Spring.TraceRayInDirection
*
* Returns all unit and/or feature hits along a ray, sorted by distance
* from the start position.
*
* @param posX number
* @param posY number
* @param posZ number
* @param dirX number
* @param dirY number
* @param dirZ number
* @param maxLength number
* @param type string Object type to test: `"unit"`, `"feature"`, or `"both"`
* @return table[] hits Array of `{hitLength, objectID, objectType}` entries
*/
int LuaSyncedRead::TraceRayInDirection(lua_State* L)
{
float3 pos(luaL_checkfloat(L, 1), luaL_checkfloat(L, 2), luaL_checkfloat(L, 3));
float3 dir(luaL_checkfloat(L, 4), luaL_checkfloat(L, 5), luaL_checkfloat(L, 6));
const float maxLen = luaL_optfloat(L, 7, 999999.f);
const char* type = luaL_checkstring(L, 8);
return TraceRayImpl(L, pos, dir, maxLen, type);
}

/*** Traces a ray between two positions
*
* @function Spring.TraceRayBetweenPositions
*
* Checks for unit and/or feature collisions between two positions
* and returns all hits sorted by distance from the start position.
*
* @param startX number
* @param startY number
* @param startZ number
* @param endX number
* @param endY number
* @param endZ number
* @param type string Object type to test: `"unit"`, `"feature"`, or `"both"`
* @return table[] hits Array of `{hitLength, objectID, objectType}` entries
*/
int LuaSyncedRead::TraceRayBetweenPositions(lua_State* L)
{
float3 start(luaL_checkfloat(L, 1), luaL_checkfloat(L, 2), luaL_checkfloat(L, 3));
float3 end(luaL_checkfloat(L, 4), luaL_checkfloat(L, 5), luaL_checkfloat(L, 6));
const char* type = luaL_checkstring(L, 7);
const auto [dir, length] = (end - start).GetNormalized();
return TraceRayImpl(L, start, dir, length, type);
}

static int TraceRayGroundImpl(lua_State *const L, const float3 &pos, const float3 &dir, const float maxLen, const bool testWater)
{
const float rayLength = CGround::LineGroundWaterCol(pos, dir, maxLen, testWater, CLuaHandle::GetHandleSynced(L));
Expand Down
5 changes: 2 additions & 3 deletions rts/Lua/LuaSyncedRead.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,8 @@ class LuaSyncedRead {

static int GetRadarErrorParams(lua_State* L);

static int TraceRay(lua_State* L); //TODO: not implemented
static int TraceRayUnits(lua_State* L); //TODO: not implemented
static int TraceRayFeatures(lua_State* L); //TODO: not implemented
static int TraceRayInDirection(lua_State* L);
static int TraceRayBetweenPositions(lua_State* L);
static int TraceRayGroundBetweenPositions(lua_State* L);
static int TraceRayGroundInDirection(lua_State* L);
};
Expand Down
Loading