00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #include <base/Expression>
00026
00027 #include <base/Matrix>
00028 #include <base/Math>
00029 #include <base/ConstantExpression>
00030 #include <base/VariableExpression>
00031 #include <base/SumExpression>
00032 #include <base/DifferenceExpression>
00033 #include <base/NegateExpression>
00034 #include <base/ProductExpression>
00035 #include <base/QuotientExpression>
00036 #include <base/SinExpression>
00037 #include <base/CosExpression>
00038
00039
00040 using base::Math;
00041 using base::Matrix;
00042 using base::Expression;
00043 using base::ExpressionNode;
00044 using base::ConstantExpression;
00045 using base::VariableExpression;
00046 using base::SumExpression;
00047 using base::DifferenceExpression;
00048 using base::NegateExpression;
00049 using base::ProductExpression;
00050 using base::QuotientExpression;
00051 using base::SinExpression;
00052 using base::CosExpression;
00053
00054
00055
00056 Expression::Expression()
00057 {
00058 expr = ref<ExpressionNode>(NewObj ConstantExpression(0));
00059 }
00060
00061 Expression::Expression(Real constant)
00062 {
00063 expr = ref<ExpressionNode>(NewObj ConstantExpression(constant));
00064 }
00065
00066 Expression::Expression(const Expression& e)
00067 {
00068 expr = e.expr;
00069 }
00070
00071 Expression::Expression(const String& exprString)
00072 {
00073 Int pos=0;
00074 expr = expression(exprString,pos).expr;
00075 }
00076
00077
00078
00079 base::Real Expression::evaluate(const Vector& params) const
00080 {
00081 expr->resetCache();
00082 return expr->evaluate(params);
00083 }
00084
00085 Expression Expression::differentiate( Expression withRespectTo ) const
00086 {
00087 if ( withRespectTo.expr->opType() != ExpressionNode::Variable )
00088 throw std::invalid_argument(Exception("Must pass a simple variable expression, such as Expression::p[2]"));
00089
00090 ref<VariableExpression> vexpr( narrow_ref<VariableExpression>(withRespectTo.expr) );
00091 return Expression(expr->differentiate(vexpr->index));
00092 }
00093
00094
00095 void Expression::simplify()
00096 {
00097 expr = simplifyConstantExpressions(expr);
00098 }
00099
00100 void Expression::operationCounts(Int& addsub, Int& multdiv, Int& trig) const
00101 {
00102 expr->operationCounts(addsub,multdiv,trig);
00103 }
00104
00105
00106 base::String Expression::toString() const
00107 {
00108 return expr->toString();
00109 }
00110
00111 Expression Expression::VariableIndexer::operator[](Int i) const
00112 {
00113 ref<ExpressionNode> expr(NewObj VariableExpression(i));
00114 return Expression(expr);
00115 }
00116
00117
// Definition of the static indexer object. Writing Expression::p[i] calls
// VariableIndexer::operator[] and yields an expression for the variable with
// index i.
Expression::VariableIndexer Expression::p;
00119
00120
00121
00122
00123
00124
00125 bool Expression::peek(const String& s, Int pos, String next)
00126 {
00127 if (pos + next.size() > s.size()) return false;
00128
00129 return (s.find(next, pos) == pos);
00130 }
00131
00132 inline bool isAlpha(String::value_type c)
00133 {
00134 return ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
00135 }
00136
00137 inline bool isNum(String::value_type c)
00138 {
00139 return ((c >= '0') && (c <= '9'));
00140 }
00141
00142 inline bool isAlphaNum(String::value_type c)
00143 {
00144 return isAlpha(c) || isNum(c);
00145 }
00146
00147
00148 SInt Expression::index(const String& s, Int& pos)
00149 {
00150 SInt sign=1;
00151 if (peek(s,pos,'-')) {
00152 sign = -1;
00153 ++pos;
00154 }
00155 else
00156 if (peek(s,pos,'+')) ++pos;
00157
00158 Int v=0;
00159 while ( isNum(s[pos]) ) {
00160 v = (10*v) + Int(s[pos]-'0');
00161 ++pos;
00162 }
00163
00164 return SInt(sign*v);
00165 }
00166
00167
00168 Real Expression::real(const String& s, Int& pos)
00169 {
00170 Real sign=1.0;
00171 if (peek(s,pos,'-')) {
00172 sign = -1.0;
00173 ++pos;
00174 }
00175 else
00176 if (peek(s,pos,'+')) ++pos;
00177
00178 Real v=0;
00179 while ( isNum(s[pos]) ) {
00180 v = (10.0*v) + Int(s[pos]-'0');
00181 ++pos;
00182 }
00183 if (peek(s,pos,'.')) {
00184 ++pos;
00185 Real m=0.1;
00186 while (isNum(s[pos])) {
00187 v = v + Real(Int(s[pos]-'0'))*m;
00188 m *= 0.1;
00189 ++pos;
00190 }
00191 }
00192 if (peek(s,pos,'e') || peek(s,pos,'E')) {
00193 if (peek(s,pos+1,'-') || (peek(s,pos+1,'+'))) {
00194 ++pos;
00195 Real esign = peek(s,pos,'-')?-1.0:1.0;
00196 ++pos;
00197 Real exp=0;
00198 while ( isNum(s[pos]) ) {
00199 exp = (10.0*exp) + Int(s[pos]-'0');
00200 ++pos;
00201 }
00202 v = v * Math::pow(10.0,esign*exp);
00203 }
00204 }
00205
00206 return sign*v;
00207 }
00208
00209
00210 Expression Expression::expression(const String& s, Int& pos)
00211 {
00212 Expression lhs = term(s,pos);
00213 while (peek(s,pos,'+') || peek(s,pos,'-')) {
00214 String::value_type op = s[pos++];
00215 Expression rhs = term(s,pos);
00216 if (op == '+')
00217 lhs = lhs + rhs;
00218 else
00219 lhs = lhs - rhs;
00220 }
00221 return lhs;
00222 }
00223
00224
00225 Expression Expression::term(const String& s, Int& pos)
00226 {
00227 Expression lhs = prod(s, pos);
00228 while (peek(s,pos,'*') || peek(s,pos,'/')) {
00229 String::value_type op = s[pos++];
00230 Expression rhs = prod(s,pos);
00231 if (op == '*')
00232 lhs = lhs * rhs;
00233 else
00234 lhs = lhs / rhs;
00235 }
00236 return lhs;
00237 }
00238
00239
00240 Expression Expression::prod(const String& s, Int& pos)
00241 {
00242 Expression e;
00243 if (peek(s,pos,'(')) {
00244 ++pos;
00245 e = expression(s,pos);
00246 if (!peek(s,pos,')')) throw std::invalid_argument(Exception(String("expecting ')' but got '")+s[pos]+"' in expression"));
00247 ++pos;
00248 }
00249 else if (peek(s,pos,"p[")) {
00250 pos += 2;
00251 Int i = index(s,pos);
00252 if (!peek(s,pos,']')) throw std::invalid_argument(Exception(String("expecting ']' but got '")+s[pos]+"' after 'p[<index>' in expression"));
00253 ++pos;
00254 e = Expression::p[i];
00255 }
00256 else if (peek(s,pos,"cos(")) {
00257 pos += 4;
00258 Expression arg = expression(s,pos);
00259 e = base::cos(arg);
00260 if (!peek(s,pos,')')) throw std::invalid_argument(Exception(String("expecting ')' but got '")+s[pos]+"' after 'cos(<expression>' in expression"));
00261 ++pos;
00262 }
00263 else if (peek(s,pos,"sin(")) {
00264 pos += 4;
00265 Expression arg = expression(s,pos);
00266 e = base::sin(arg);
00267 if (!peek(s,pos,')')) throw std::invalid_argument(Exception(String("expecting ')' but got '")+s[pos]+"' after 'sin(<expression>' in expression"));
00268 ++pos;
00269 }
00270 else if (peek(s,pos,"tan(")) {
00271 pos += 4;
00272 Expression arg = expression(s,pos);
00273 e = base::sin(arg) / base::cos(arg);
00274 if (!peek(s,pos,')')) throw std::invalid_argument(Exception(String("expecting ')' but got '")+s[pos]+"' after 'tan(<expression>' in expression"));
00275 ++pos;
00276 }
00277 else if (peek(s,pos,"pi")) {
00278 pos += 2;
00279 e = Expression(consts::Pi);
00280 }
00281 else if (peek(s,pos,'s')) {
00282 ++pos;
00283 e = Expression::p[0];
00284 }
00285 else {
00286 Real v = real(s,pos);
00287 e = Expression(v);
00288 }
00289
00290 return e;
00291 }
00292
00293
00294
00295
00296 Expression& Expression::operator+=(const Expression& e)
00297 {
00298 expr = ref<ExpressionNode>(NewObj SumExpression(expr,e.expr));
00299 return *this;
00300 }
00301
00302 Expression& Expression::operator-=(const Expression& e)
00303 {
00304 expr = ref<ExpressionNode>(NewObj DifferenceExpression(expr,e.expr));
00305 return *this;
00306 }
00307
00308 Expression& Expression::operator*=(const Expression& e)
00309 {
00310 expr = ref<ExpressionNode>(NewObj ProductExpression(expr,e.expr));
00311 return *this;
00312 }
00313
00314 Expression& Expression::operator/=(const Expression& e)
00315 {
00316 expr = ref<ExpressionNode>(NewObj QuotientExpression(expr,e.expr));
00317 return *this;
00318 }
00319
00320 Expression& Expression::negate()
00321 {
00322 expr = ref<ExpressionNode>(NewObj NegateExpression(expr));
00323 return *this;
00324 }
00325
00326 Expression& Expression::sin()
00327 {
00328 expr = ref<ExpressionNode>(NewObj SinExpression(expr));
00329 return *this;
00330 }
00331
00332 Expression& Expression::cos()
00333 {
00334 expr = ref<ExpressionNode>(NewObj CosExpression(expr));
00335 return *this;
00336 }
00337
00338
/// Serializes the expression tree through `s`.
///
/// First registers an instantiator for each concrete ExpressionNode subclass
/// used in this file (sum, difference, product, quotient, negate, constant,
/// variable, sin, cos). Presumably this lets the node tree be recreated by
/// concrete type when it is read back. NOTE(review): this runs on every call;
/// confirm registration is idempotent.
/// Then writes the human-readable form as a comment, followed by the root
/// node reference under the name "expression".
void Expression::serialize(Serializer& s)
{

Serializable::registerSerializableInstantiator<ExpressionNode,SumExpression>(sumInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,DifferenceExpression>(differenceInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,ProductExpression>(productInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,QuotientExpression>(quotientInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,NegateExpression>(negateInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,ConstantExpression>(constantInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,VariableExpression>(variableInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,SinExpression>(sinInstantiator);
Serializable::registerSerializableInstantiator<ExpressionNode,CosExpression>(cosInstantiator);
// human-readable rendering, for anyone inspecting the serialized output
s.comment(String("Symbolic expression: "+toString()));
s.baseRef(expr,"expression");
}
00354
00355
00356
00357
00358
00359 ref<ExpressionNode> Expression::simplifyConstantExpressions(ref<ExpressionNode> expr)
00360 {
00361 if (!expr) { Assert(expr); }
00362
00363 ref<ExpressionNode> sexpr(expr);
00364
00365 if (sexpr->isBinaryOp()) {
00366 ref<BinaryOpExpression> binExpr(narrow_ref<BinaryOpExpression>(sexpr));
00367
00368
00369 binExpr->leftArg = simplifyConstantExpressions(binExpr->leftArg);
00370 binExpr->rightArg = simplifyConstantExpressions(binExpr->rightArg);
00371
00372
00373 bool leftIsConst = binExpr->leftArg->opType() == ExpressionNode::Constant;
00374 bool rightIsConst = binExpr->rightArg->opType() == ExpressionNode::Constant;
00375 ref<ConstantExpression> leftConst;
00376 if (leftIsConst) leftConst = narrow_ref<ConstantExpression>(binExpr->leftArg);
00377 bool leftIsZero = leftIsConst?(Math::equals(leftConst->constValue,0)):false;
00378 ref<ConstantExpression> rightConst;
00379 if (rightIsConst) rightConst = narrow_ref<ConstantExpression>(binExpr->rightArg);
00380 bool rightIsZero = rightIsConst?(Math::equals(rightConst->constValue,0)):false;
00381
00382
00383
00384
00385
00386 if (binExpr->opType() == ExpressionNode::Sum) {
00387
00388 if (leftIsConst) {
00389 if (rightIsConst)
00390 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(leftConst->constValue + rightConst->constValue));
00391 else
00392 if (leftIsZero)
00393 sexpr = binExpr->rightArg;
00394 }
00395 else
00396 if (rightIsConst && rightIsZero)
00397 sexpr = binExpr->leftArg;
00398
00399 }
00400 else if (binExpr->opType() == ExpressionNode::Difference) {
00401
00402 if (leftIsConst) {
00403 if (rightIsConst)
00404 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(leftConst->constValue - rightConst->constValue));
00405 else
00406 if (leftIsZero)
00407 sexpr = ref<ExpressionNode>(NewObj NegateExpression(binExpr->rightArg));
00408 }
00409 else
00410 if (rightIsConst && rightIsZero)
00411 sexpr = binExpr->leftArg;
00412
00413 }
00414 else if (binExpr->opType() == ExpressionNode::Product) {
00415
00416 bool leftIsOne = leftIsConst?(Math::equals(leftConst->constValue,1)):false;
00417 bool leftIsMinusOne = leftIsConst?(Math::equals(leftConst->constValue,-1)):false;
00418 bool rightIsOne = rightIsConst?(Math::equals(rightConst->constValue,1)):false;
00419 bool rightIsMinusOne = rightIsConst?(Math::equals(rightConst->constValue,-1)):false;
00420
00421 if (leftIsConst) {
00422 if (rightIsConst)
00423 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(leftConst->constValue * rightConst->constValue));
00424 else {
00425 if (leftIsZero)
00426 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(0));
00427 else
00428 if (leftIsOne)
00429 sexpr = binExpr->rightArg;
00430 else
00431 if (leftIsMinusOne) {
00432 if ( binExpr->rightArg->opType() == ExpressionNode::Negative)
00433 sexpr = narrow_ref<NegateExpression>( binExpr->rightArg )->arg;
00434 else sexpr = ref<NegateExpression>( NewObj NegateExpression( binExpr->rightArg ));
00435 }
00436
00437 }
00438 }
00439 else
00440 if (rightIsConst) {
00441 if (rightIsZero)
00442 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(0));
00443 else
00444 if (rightIsOne)
00445 sexpr = binExpr->leftArg;
00446 else
00447 if (rightIsMinusOne) {
00448 if ( binExpr->leftArg->opType() == ExpressionNode::Negative )
00449 sexpr = narrow_ref<NegateExpression>( binExpr->leftArg )->arg;
00450 else sexpr = ref<NegateExpression>( NewObj NegateExpression( binExpr->leftArg ));
00451 }
00452 }
00453
00454 }
00455 else if (binExpr->opType() == ExpressionNode::Quotient) {
00456
00457 bool rightIsOne = rightIsConst?(equals(rightConst->constValue,1)):false;
00458
00459 if (leftIsConst) {
00460 if (rightIsConst)
00461 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(leftConst->constValue / rightConst->constValue));
00462 else {
00463 if (leftIsZero)
00464 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(0));
00465 }
00466
00467 }
00468 else
00469 if (rightIsConst) {
00470 if (rightIsZero) {
00471 throw std::out_of_range(Exception("cannot divide by constant 0"));
00472 }
00473 else
00474 if (rightIsOne)
00475 sexpr = binExpr->leftArg;
00476 }
00477
00478 }
00479
00480 }
00481 else if (sexpr->isUnaryOp()) {
00482 ref<UnaryOpExpression> unaryExpr(narrow_ref<UnaryOpExpression>(sexpr));
00483
00484
00485 unaryExpr->arg = simplifyConstantExpressions(unaryExpr->arg);
00486
00487
00488 bool isConst = unaryExpr->arg->opType() == ExpressionNode::Constant;
00489 ref<ConstantExpression> argConst;
00490 if (isConst) argConst = narrow_ref<ConstantExpression>(unaryExpr->arg);
00491
00492 if (unaryExpr->opType() == ExpressionNode::Sine) {
00493 if (isConst)
00494 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(Math::sin(argConst->constValue)));
00495 }
00496 else if (unaryExpr->opType() == ExpressionNode::Cosine) {
00497 if (isConst)
00498 sexpr = ref<ExpressionNode>(NewObj ConstantExpression(Math::cos(argConst->constValue)));
00499 }
00500 else if (unaryExpr->opType() == ExpressionNode::Negative) {
00501 ref<NegateExpression> nexpr( narrow_ref<NegateExpression>(unaryExpr) );
00502 if (nexpr->arg->opType() == ExpressionNode::Constant) {
00503 bool isZero = Math::equals(ref<ConstantExpression>(narrow_ref<ConstantExpression>(nexpr->arg))->constValue,0);
00504 if (isZero)
00505 sexpr = nexpr->arg;
00506 }
00507 else if (nexpr->arg->opType() == ExpressionNode::Negative) {
00508 sexpr = narrow_ref<NegateExpression>(nexpr->arg)->arg;
00509 }
00510 }
00511
00512 }
00513
00514 return sexpr;
00515 }
00516
00517
00518
00519
00520 void base::simplify( ExpressionMatrix& m )
00521 {
00522 for ( Int r=0; r < m.size1(); r++ ) {
00523 for ( Int c=0; c < m.size2(); c++ ) {
00524 m(r,c).simplify();
00525 }
00526 }
00527 return;
00528 }
00529
00530
00531
00532 base::Matrix base::evaluate( const ExpressionMatrix& m, const Vector& params )
00533 {
00534 Matrix em(m.size1(), m.size2());
00535 for ( Int r=0; r < m.size1(); r++ )
00536 for ( Int c=0; c < m.size2(); c++ )
00537 em(r,c) = m(r,c).evaluate(params);
00538 return em;
00539 }
00540
00541
00542 base::ExpressionMatrix base::toExpressionMatrix(const Matrix& m)
00543 {
00544 ExpressionMatrix em(m.size1(), m.size2());
00545 for( Int r=0; r < m.size1(); r++ )
00546 for ( Int c=0; c < m.size2(); c++ )
00547 em(r,c)=m(r,c);
00548 return em;
00549 }